author    Ben Banfield-Zanin <benbz@matrix.org>  2021-01-14 15:02:42 +0000
committer Ben Banfield-Zanin <benbz@matrix.org>  2021-01-14 15:02:42 +0000
commit    3cbcf6b923d6339c5c19c03e0edc4f81c0199867 (patch)
tree      681e63f940bc91db06f4604cf20a773e9c8661ef
parent    Merge remote-tracking branch 'origin/release-v1.21.2' into toml/keycloak_hints (diff)
parent    Move removal warning up changelog (diff)
download  synapse-3cbcf6b923d6339c5c19c03e0edc4f81c0199867.tar.xz
Merge remote-tracking branch 'origin/release-v1.25.0' into toml/keycloak_hints
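For readers unfamiliar with how a merge commit like this is produced: it is the result of fetching the upstream release branch and merging it into the local feature branch. A minimal sketch of that workflow follows, assuming the remote is named `origin`; the branch names are taken from the commit message above.

```sh
# Fetch the latest upstream refs, then merge the release branch into the
# feature branch. The two parents listed above are the previous tip of
# toml/keycloak_hints and the tip of origin/release-v1.25.0.
git fetch origin
git checkout toml/keycloak_hints
git merge origin/release-v1.25.0
```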
-rwxr-xr-x.buildkite/scripts/test_old_deps.sh2
-rw-r--r--.buildkite/test_db.dbbin19279872 -> 19296256 bytes
-rw-r--r--.buildkite/worker-blacklist31
-rw-r--r--.circleci/config.yml63
-rw-r--r--.gitignore1
-rw-r--r--CHANGES.md864
-rw-r--r--CONTRIBUTING.md24
-rw-r--r--INSTALL.md272
-rw-r--r--README.rst67
-rw-r--r--UPGRADE.rst128
-rw-r--r--changelog.d/8530.bugfix1
-rw-r--r--contrib/grafana/README.md2
-rw-r--r--contrib/prometheus/README.md10
-rw-r--r--contrib/prometheus/consoles/synapse.html101
-rw-r--r--contrib/prometheus/synapse-v2.rules18
-rw-r--r--debian/changelog48
-rw-r--r--debian/control6
-rwxr-xr-xdemo/start.sh2
-rw-r--r--docker/Dockerfile5
-rw-r--r--docker/Dockerfile-dhvirtualenv13
-rw-r--r--docker/README.md16
-rw-r--r--docker/conf/homeserver.yaml2
-rwxr-xr-xdocker/start.py45
-rw-r--r--docs/admin_api/event_reports.md172
-rw-r--r--docs/admin_api/event_reports.rst129
-rw-r--r--docs/admin_api/media_admin_api.md166
-rw-r--r--docs/admin_api/purge_remote_media.rst20
-rw-r--r--docs/admin_api/purge_room.md9
-rw-r--r--docs/admin_api/register_api.rst4
-rw-r--r--docs/admin_api/rooms.md92
-rw-r--r--docs/admin_api/shutdown_room.md7
-rw-r--r--docs/admin_api/statistics.md83
-rw-r--r--docs/admin_api/user_admin_api.rst220
-rw-r--r--docs/code_style.md2
-rw-r--r--docs/dev/cas.md6
-rw-r--r--docs/manhole.md53
-rw-r--r--docs/message_retention_policies.md34
-rw-r--r--docs/metrics-howto.md74
-rw-r--r--docs/openid.md76
-rw-r--r--docs/password_auth_providers.md1
-rw-r--r--docs/reverse_proxy.md2
-rw-r--r--docs/sample_config.yaml350
-rw-r--r--docs/sample_log_config.yaml6
-rw-r--r--docs/spam_checker.md28
-rw-r--r--docs/sphinx/README.rst1
-rw-r--r--docs/sphinx/conf.py271
-rw-r--r--docs/sphinx/index.rst20
-rw-r--r--docs/sphinx/modules.rst7
-rw-r--r--docs/sphinx/synapse.api.auth.rst7
-rw-r--r--docs/sphinx/synapse.api.constants.rst7
-rw-r--r--docs/sphinx/synapse.api.dbobjects.rst7
-rw-r--r--docs/sphinx/synapse.api.errors.rst7
-rw-r--r--docs/sphinx/synapse.api.event_stream.rst7
-rw-r--r--docs/sphinx/synapse.api.events.factory.rst7
-rw-r--r--docs/sphinx/synapse.api.events.room.rst7
-rw-r--r--docs/sphinx/synapse.api.events.rst18
-rw-r--r--docs/sphinx/synapse.api.handlers.events.rst7
-rw-r--r--docs/sphinx/synapse.api.handlers.factory.rst7
-rw-r--r--docs/sphinx/synapse.api.handlers.federation.rst7
-rw-r--r--docs/sphinx/synapse.api.handlers.register.rst7
-rw-r--r--docs/sphinx/synapse.api.handlers.room.rst7
-rw-r--r--docs/sphinx/synapse.api.handlers.rst21
-rw-r--r--docs/sphinx/synapse.api.notifier.rst7
-rw-r--r--docs/sphinx/synapse.api.register_events.rst7
-rw-r--r--docs/sphinx/synapse.api.room_events.rst7
-rw-r--r--docs/sphinx/synapse.api.rst30
-rw-r--r--docs/sphinx/synapse.api.server.rst7
-rw-r--r--docs/sphinx/synapse.api.storage.rst7
-rw-r--r--docs/sphinx/synapse.api.stream.rst7
-rw-r--r--docs/sphinx/synapse.api.streams.event.rst7
-rw-r--r--docs/sphinx/synapse.api.streams.rst17
-rw-r--r--docs/sphinx/synapse.app.homeserver.rst7
-rw-r--r--docs/sphinx/synapse.app.rst17
-rw-r--r--docs/sphinx/synapse.db.rst10
-rw-r--r--docs/sphinx/synapse.federation.handler.rst7
-rw-r--r--docs/sphinx/synapse.federation.messaging.rst7
-rw-r--r--docs/sphinx/synapse.federation.pdu_codec.rst7
-rw-r--r--docs/sphinx/synapse.federation.persistence.rst7
-rw-r--r--docs/sphinx/synapse.federation.replication.rst7
-rw-r--r--docs/sphinx/synapse.federation.rst22
-rw-r--r--docs/sphinx/synapse.federation.transport.rst7
-rw-r--r--docs/sphinx/synapse.federation.units.rst7
-rw-r--r--docs/sphinx/synapse.persistence.rst19
-rw-r--r--docs/sphinx/synapse.persistence.service.rst7
-rw-r--r--docs/sphinx/synapse.persistence.tables.rst7
-rw-r--r--docs/sphinx/synapse.persistence.transactions.rst7
-rw-r--r--docs/sphinx/synapse.rest.base.rst7
-rw-r--r--docs/sphinx/synapse.rest.events.rst7
-rw-r--r--docs/sphinx/synapse.rest.register.rst7
-rw-r--r--docs/sphinx/synapse.rest.room.rst7
-rw-r--r--docs/sphinx/synapse.rest.rst20
-rw-r--r--docs/sphinx/synapse.rst30
-rw-r--r--docs/sphinx/synapse.server.rst7
-rw-r--r--docs/sphinx/synapse.state.rst7
-rw-r--r--docs/sphinx/synapse.util.async.rst7
-rw-r--r--docs/sphinx/synapse.util.dbutils.rst7
-rw-r--r--docs/sphinx/synapse.util.http.rst7
-rw-r--r--docs/sphinx/synapse.util.lockutils.rst7
-rw-r--r--docs/sphinx/synapse.util.logutils.rst7
-rw-r--r--docs/sphinx/synapse.util.rst21
-rw-r--r--docs/sphinx/synapse.util.stringutils.rst7
-rw-r--r--docs/sso_mapping_providers.md45
-rw-r--r--docs/structured_logging.md164
-rw-r--r--docs/systemd-with-workers/README.md4
-rw-r--r--docs/tcp_replication.md13
-rw-r--r--docs/turn-howto.md131
-rw-r--r--docs/workers.md45
-rw-r--r--mypy.ini60
-rw-r--r--pyproject.toml2
-rwxr-xr-xscripts-dev/lint.sh94
-rw-r--r--scripts-dev/mypy_synapse_plugin.py40
-rwxr-xr-xscripts-dev/sign_json127
-rw-r--r--scripts-dev/sphinx_api_docs.sh1
-rwxr-xr-xscripts/generate_log_config4
-rwxr-xr-xscripts/synapse_port_db166
-rw-r--r--setup.cfg5
-rwxr-xr-xsetup.py7
-rw-r--r--stubs/sortedcontainers/__init__.pyi11
-rw-r--r--stubs/sortedcontainers/sortedlist.pyi177
-rw-r--r--stubs/txredisapi.pyi20
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/_scripts/register_new_matrix_user.py2
-rw-r--r--synapse/api/auth.py132
-rw-r--r--synapse/api/auth_blocking.py40
-rw-r--r--synapse/api/constants.py14
-rw-r--r--synapse/app/_base.py30
-rw-r--r--synapse/app/admin_cmd.py3
-rw-r--r--synapse/app/generic_worker.py44
-rw-r--r--synapse/app/homeserver.py233
-rw-r--r--synapse/app/phone_stats_home.py190
-rw-r--r--synapse/appservice/__init__.py180
-rw-r--r--synapse/appservice/api.py29
-rw-r--r--synapse/appservice/scheduler.py64
-rw-r--r--synapse/config/_base.py14
-rw-r--r--synapse/config/_base.pyi11
-rw-r--r--synapse/config/_util.py35
-rw-r--r--synapse/config/appservice.py3
-rw-r--r--synapse/config/auth.py (renamed from synapse/config/password.py)26
-rw-r--r--synapse/config/cas.py46
-rw-r--r--synapse/config/emailconfig.py27
-rw-r--r--synapse/config/federation.py43
-rw-r--r--synapse/config/groups.py2
-rw-r--r--synapse/config/homeserver.py4
-rw-r--r--synapse/config/jwt_config.py2
-rw-r--r--synapse/config/logger.py98
-rw-r--r--synapse/config/oidc_config.py23
-rw-r--r--synapse/config/password_auth_providers.py5
-rw-r--r--synapse/config/push.py48
-rw-r--r--synapse/config/registration.py7
-rw-r--r--synapse/config/repository.py30
-rw-r--r--synapse/config/room_directory.py4
-rw-r--r--synapse/config/saml2_config.py121
-rw-r--r--synapse/config/server.py94
-rw-r--r--synapse/config/spam_checker.py9
-rw-r--r--synapse/config/sso.py7
-rw-r--r--synapse/config/third_party_event_rules.py4
-rw-r--r--synapse/config/tls.py18
-rw-r--r--synapse/config/tracer.py2
-rw-r--r--synapse/config/workers.py28
-rw-r--r--synapse/crypto/context_factory.py4
-rw-r--r--synapse/crypto/event_signing.py29
-rw-r--r--synapse/crypto/keyring.py210
-rw-r--r--synapse/event_auth.py2
-rw-r--r--synapse/events/__init__.py18
-rw-r--r--synapse/events/builder.py21
-rw-r--r--synapse/events/spamcheck.py63
-rw-r--r--synapse/events/third_party_rules.py75
-rw-r--r--synapse/events/utils.py7
-rw-r--r--synapse/events/validator.py30
-rw-r--r--synapse/federation/federation_base.py7
-rw-r--r--synapse/federation/federation_server.py46
-rw-r--r--synapse/federation/send_queue.py2
-rw-r--r--synapse/federation/sender/__init__.py11
-rw-r--r--synapse/federation/sender/per_destination_queue.py2
-rw-r--r--synapse/federation/transport/client.py2
-rw-r--r--synapse/federation/transport/server.py74
-rw-r--r--synapse/groups/attestations.py2
-rw-r--r--synapse/groups/groups_server.py4
-rw-r--r--synapse/handlers/__init__.py33
-rw-r--r--synapse/handlers/_base.py28
-rw-r--r--synapse/handlers/account_data.py14
-rw-r--r--synapse/handlers/account_validity.py42
-rw-r--r--synapse/handlers/admin.py67
-rw-r--r--synapse/handlers/appservice.py222
-rw-r--r--synapse/handlers/auth.py587
-rw-r--r--synapse/handlers/cas_handler.py306
-rw-r--r--synapse/handlers/deactivate_account.py18
-rw-r--r--synapse/handlers/device.py171
-rw-r--r--synapse/handlers/devicemessage.py25
-rw-r--r--synapse/handlers/directory.py18
-rw-r--r--synapse/handlers/e2e_keys.py43
-rw-r--r--synapse/handlers/federation.py116
-rw-r--r--synapse/handlers/groups_local.py9
-rw-r--r--synapse/handlers/identity.py14
-rw-r--r--synapse/handlers/initial_sync.py22
-rw-r--r--synapse/handlers/message.py385
-rw-r--r--synapse/handlers/oidc_handler.py361
-rw-r--r--synapse/handlers/pagination.py21
-rw-r--r--synapse/handlers/password_policy.py10
-rw-r--r--synapse/handlers/presence.py21
-rw-r--r--synapse/handlers/profile.py136
-rw-r--r--synapse/handlers/read_marker.py10
-rw-r--r--synapse/handlers/receipts.py71
-rw-r--r--synapse/handlers/register.py254
-rw-r--r--synapse/handlers/room.py78
-rw-r--r--synapse/handlers/room_list.py101
-rw-r--r--synapse/handlers/room_member.py210
-rw-r--r--synapse/handlers/saml_handler.py268
-rw-r--r--synapse/handlers/search.py2
-rw-r--r--synapse/handlers/sso.py571
-rw-r--r--synapse/handlers/state_deltas.py2
-rw-r--r--synapse/handlers/stats.py2
-rw-r--r--synapse/handlers/sync.py51
-rw-r--r--synapse/handlers/typing.py54
-rw-r--r--synapse/handlers/ui_auth/checkers.py2
-rw-r--r--synapse/handlers/user_directory.py87
-rw-r--r--synapse/http/client.py284
-rw-r--r--synapse/http/federation/matrix_federation_agent.py116
-rw-r--r--synapse/http/federation/well_known_resolver.py43
-rw-r--r--synapse/http/matrixfederationclient.py393
-rw-r--r--synapse/http/proxyagent.py16
-rw-r--r--synapse/http/request_metrics.py2
-rw-r--r--synapse/http/server.py56
-rw-r--r--synapse/http/servlet.py3
-rw-r--r--synapse/http/site.py53
-rw-r--r--synapse/logging/__init__.py20
-rw-r--r--synapse/logging/_remote.py241
-rw-r--r--synapse/logging/_structured.py337
-rw-r--r--synapse/logging/_terse_json.py365
-rw-r--r--synapse/logging/context.py24
-rw-r--r--synapse/logging/filter.py33
-rw-r--r--synapse/logging/opentracing.py10
-rw-r--r--synapse/metrics/__init__.py10
-rw-r--r--synapse/metrics/background_process_metrics.py29
-rw-r--r--synapse/module_api/__init__.py127
-rw-r--r--synapse/notifier.py119
-rw-r--r--synapse/push/__init__.py108
-rw-r--r--synapse/push/action_generator.py15
-rw-r--r--synapse/push/baserules.py49
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py135
-rw-r--r--synapse/push/clientformat.py23
-rw-r--r--synapse/push/emailpusher.py108
-rw-r--r--synapse/push/httppusher.py134
-rw-r--r--synapse/push/mailer.py186
-rw-r--r--synapse/push/presentable_names.py48
-rw-r--r--synapse/push/push_rule_evaluator.py48
-rw-r--r--synapse/push/push_tools.py21
-rw-r--r--synapse/push/pusher.py34
-rw-r--r--synapse/push/pusherpool.py192
-rw-r--r--synapse/python_dependencies.py15
-rw-r--r--synapse/replication/http/_base.py49
-rw-r--r--synapse/replication/http/federation.py12
-rw-r--r--synapse/replication/http/login.py12
-rw-r--r--synapse/replication/http/membership.py74
-rw-r--r--synapse/replication/http/send_event.py19
-rw-r--r--synapse/replication/slave/storage/_slaved_id_tracker.py20
-rw-r--r--synapse/replication/slave/storage/client_ips.py10
-rw-r--r--synapse/replication/slave/storage/pushers.py17
-rw-r--r--synapse/replication/tcp/client.py24
-rw-r--r--synapse/replication/tcp/commands.py36
-rw-r--r--synapse/replication/tcp/handler.py30
-rw-r--r--synapse/replication/tcp/protocol.py13
-rw-r--r--synapse/replication/tcp/redis.py44
-rw-r--r--synapse/replication/tcp/resource.py57
-rw-r--r--synapse/replication/tcp/streams/_base.py11
-rw-r--r--synapse/replication/tcp/streams/events.py27
-rw-r--r--synapse/res/templates/notif.html56
-rw-r--r--synapse/res/templates/notif.txt24
-rw-r--r--synapse/res/templates/notif_mail.html26
-rw-r--r--synapse/res/templates/notif_mail.txt6
-rw-r--r--synapse/res/templates/room.html26
-rw-r--r--synapse/res/templates/room.txt12
-rw-r--r--synapse/res/username_picker/index.html19
-rw-r--r--synapse/res/username_picker/script.js95
-rw-r--r--synapse/res/username_picker/style.css27
-rw-r--r--synapse/rest/admin/__init__.py33
-rw-r--r--synapse/rest/admin/_base.py22
-rw-r--r--synapse/rest/admin/devices.py2
-rw-r--r--synapse/rest/admin/event_reports.py46
-rw-r--r--synapse/rest/admin/groups.py7
-rw-r--r--synapse/rest/admin/media.py96
-rw-r--r--synapse/rest/admin/rooms.py217
-rw-r--r--synapse/rest/admin/statistics.py122
-rw-r--r--synapse/rest/admin/users.py209
-rw-r--r--synapse/rest/client/v1/directory.py25
-rw-r--r--synapse/rest/client/v1/events.py3
-rw-r--r--synapse/rest/client/v1/login.py165
-rw-r--r--synapse/rest/client/v1/logout.py6
-rw-r--r--synapse/rest/client/v1/presence.py3
-rw-r--r--synapse/rest/client/v1/profile.py10
-rw-r--r--synapse/rest/client/v1/push_rule.py3
-rw-r--r--synapse/rest/client/v1/pusher.py24
-rw-r--r--synapse/rest/client/v1/room.py49
-rw-r--r--synapse/rest/client/v1/voip.py3
-rw-r--r--synapse/rest/client/v2_alpha/account.py48
-rw-r--r--synapse/rest/client/v2_alpha/auth.py3
-rw-r--r--synapse/rest/client/v2_alpha/devices.py134
-rw-r--r--synapse/rest/client/v2_alpha/groups.py48
-rw-r--r--synapse/rest/client/v2_alpha/keys.py37
-rw-r--r--synapse/rest/client/v2_alpha/register.py38
-rw-r--r--synapse/rest/client/v2_alpha/sendtodevice.py3
-rw-r--r--synapse/rest/client/v2_alpha/sync.py2
-rw-r--r--synapse/rest/key/v2/local_key_resource.py2
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py9
-rw-r--r--synapse/rest/media/v1/_base.py11
-rw-r--r--synapse/rest/media/v1/filepath.py17
-rw-r--r--synapse/rest/media/v1/media_repository.py239
-rw-r--r--synapse/rest/media/v1/media_storage.py30
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py6
-rw-r--r--synapse/rest/media/v1/storage_provider.py16
-rw-r--r--synapse/rest/media/v1/upload_resource.py2
-rw-r--r--synapse/rest/synapse/client/pick_username.py88
-rw-r--r--synapse/server.py150
-rw-r--r--synapse/server_notices/consent_server_notices.py2
-rw-r--r--synapse/server_notices/server_notices_manager.py15
-rw-r--r--synapse/spam_checker_api/__init__.py43
-rw-r--r--synapse/state/__init__.py8
-rw-r--r--synapse/state/v1.py2
-rw-r--r--synapse/state/v2.py94
-rw-r--r--synapse/static/client/login/js/login.js2
-rw-r--r--synapse/storage/__init__.py9
-rw-r--r--synapse/storage/_base.py48
-rw-r--r--synapse/storage/background_updates.py111
-rw-r--r--synapse/storage/database.py197
-rw-r--r--synapse/storage/databases/__init__.py2
-rw-r--r--synapse/storage/databases/main/__init__.py203
-rw-r--r--synapse/storage/databases/main/account_data.py9
-rw-r--r--synapse/storage/databases/main/appservice.py171
-rw-r--r--synapse/storage/databases/main/censor_events.py21
-rw-r--r--synapse/storage/databases/main/client_ips.py124
-rw-r--r--synapse/storage/databases/main/devices.py321
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py129
-rw-r--r--synapse/storage/databases/main/event_federation.py146
-rw-r--r--synapse/storage/databases/main/event_push_actions.py257
-rw-r--r--synapse/storage/databases/main/events.py73
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py7
-rw-r--r--synapse/storage/databases/main/events_worker.py231
-rw-r--r--synapse/storage/databases/main/keys.py15
-rw-r--r--synapse/storage/databases/main/media_repository.py131
-rw-r--r--synapse/storage/databases/main/metrics.py213
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py110
-rw-r--r--synapse/storage/databases/main/profile.py92
-rw-r--r--synapse/storage/databases/main/purge_events.py2
-rw-r--r--synapse/storage/databases/main/pusher.py95
-rw-r--r--synapse/storage/databases/main/receipts.py71
-rw-r--r--synapse/storage/databases/main/registration.py534
-rw-r--r--synapse/storage/databases/main/room.py310
-rw-r--r--synapse/storage/databases/main/roommember.py68
-rw-r--r--synapse/storage/databases/main/schema/delta/20/pushers.py19
-rw-r--r--synapse/storage/databases/main/schema/delta/25/fts.py2
-rw-r--r--synapse/storage/databases/main/schema/delta/27/ts.py2
-rw-r--r--synapse/storage/databases/main/schema/delta/30/as_users.py6
-rw-r--r--synapse/storage/databases/main/schema/delta/31/pushers.py19
-rw-r--r--synapse/storage/databases/main/schema/delta/31/search_update.py2
-rw-r--r--synapse/storage/databases/main/schema/delta/33/event_fields.py2
-rw-r--r--synapse/storage/databases/main/schema/delta/33/remote_media_ts.py5
-rw-r--r--synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py7
-rw-r--r--synapse/storage/databases/main/schema/delta/57/local_current_membership.py1
-rw-r--r--synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres12
-rw-r--r--synapse/storage/databases/main/schema/delta/58/11dehydration.sql20
-rw-r--r--synapse/storage/databases/main/schema/delta/58/11fallback.sql24
-rw-r--r--synapse/storage/databases/main/schema/delta/58/12room_stats.sql4
-rw-r--r--synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres25
-rw-r--r--synapse/storage/databases/main/schema/delta/58/19txn_id.sql40
-rw-r--r--synapse/storage/databases/main/schema/delta/58/20instance_name_event_tables.sql17
-rw-r--r--synapse/storage/databases/main/schema/delta/58/20user_daily_visits.sql18
-rw-r--r--synapse/storage/databases/main/schema/delta/58/21as_device_stream.sql17
-rw-r--r--synapse/storage/databases/main/schema/delta/58/21drop_device_max_stream_id.sql1
-rw-r--r--synapse/storage/databases/main/schema/delta/58/22puppet_token.sql17
-rw-r--r--synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql2
-rw-r--r--synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql17
-rw-r--r--synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql19
-rw-r--r--synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql17
-rw-r--r--synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql18
-rw-r--r--synapse/storage/databases/main/schema/delta/58/27local_invites.sql18
-rw-r--r--synapse/storage/databases/main/stats.py127
-rw-r--r--synapse/storage/databases/main/stream.py295
-rw-r--r--synapse/storage/databases/main/transactions.py118
-rw-r--r--synapse/storage/databases/main/ui_auth.py6
-rw-r--r--synapse/storage/databases/main/user_directory.py76
-rw-r--r--synapse/storage/keys.py5
-rw-r--r--synapse/storage/persist_events.py298
-rw-r--r--synapse/storage/prepare_database.py137
-rw-r--r--synapse/storage/purge_events.py11
-rw-r--r--synapse/storage/relations.py44
-rw-r--r--synapse/storage/state.py35
-rw-r--r--synapse/storage/types.py8
-rw-r--r--synapse/storage/util/id_generators.py24
-rw-r--r--synapse/storage/util/sequence.py17
-rw-r--r--synapse/types.py163
-rw-r--r--synapse/util/__init__.py24
-rw-r--r--synapse/util/async_helpers.py8
-rw-r--r--synapse/util/caches/__init__.py13
-rw-r--r--synapse/util/caches/deferred_cache.py342
-rw-r--r--synapse/util/caches/descriptors.py521
-rw-r--r--synapse/util/caches/dictionary_cache.py29
-rw-r--r--synapse/util/caches/lrucache.py138
-rw-r--r--synapse/util/caches/response_cache.py50
-rw-r--r--synapse/util/caches/ttlcache.py2
-rw-r--r--synapse/util/distributor.py7
-rw-r--r--synapse/util/frozenutils.py22
-rw-r--r--synapse/util/module_loader.py64
-rw-r--r--synapse/util/retryutils.py2
-rw-r--r--synapse/visibility.py57
-rwxr-xr-xsynctl7
-rw-r--r--synmark/__init__.py46
-rw-r--r--synmark/__main__.py6
-rw-r--r--synmark/suites/logging.py60
-rw-r--r--sytest-blacklist3
-rw-r--r--tests/api/test_auth.py47
-rw-r--r--tests/api/test_filtering.py19
-rw-r--r--tests/api/test_ratelimiting.py4
-rw-r--r--tests/app/test_frontend_proxy.py11
-rw-r--r--tests/app/test_openid_listener.py19
-rw-r--r--tests/appservice/test_appservice.py1
-rw-r--r--tests/appservice/test_scheduler.py118
-rw-r--r--tests/crypto/test_keyring.py14
-rw-r--r--tests/federation/test_complexity.py6
-rw-r--r--tests/federation/test_federation_server.py9
-rw-r--r--tests/federation/transport/__init__.py0
-rw-r--r--tests/federation/transport/test_server.py41
-rw-r--r--tests/handlers/test_admin.py2
-rw-r--r--tests/handlers/test_appservice.py15
-rw-r--r--tests/handlers/test_auth.py17
-rw-r--r--tests/handlers/test_cas.py121
-rw-r--r--tests/handlers/test_device.py84
-rw-r--r--tests/handlers/test_directory.py30
-rw-r--r--tests/handlers/test_e2e_keys.py87
-rw-r--r--tests/handlers/test_e2e_room_keys.py2
-rw-r--r--tests/handlers/test_federation.py11
-rw-r--r--tests/handlers/test_message.py212
-rw-r--r--tests/handlers/test_oidc.py585
-rw-r--r--tests/handlers/test_password_providers.py603
-rw-r--r--tests/handlers/test_presence.py6
-rw-r--r--tests/handlers/test_profile.py9
-rw-r--r--tests/handlers/test_register.py24
-rw-r--r--tests/handlers/test_saml.py265
-rw-r--r--tests/handlers/test_sync.py14
-rw-r--r--tests/handlers/test_typing.py68
-rw-r--r--tests/handlers/test_user_directory.py50
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py30
-rw-r--r--tests/http/test_additional_resource.py15
-rw-r--r--tests/http/test_proxyagent.py130
-rw-r--r--tests/logging/__init__.py34
-rw-r--r--tests/logging/test_remote_handler.py169
-rw-r--r--tests/logging/test_structured.py214
-rw-r--r--tests/logging/test_terse_json.py265
-rw-r--r--tests/module_api/test_api.py154
-rw-r--r--tests/push/test_email.py52
-rw-r--r--tests/push/test_http.py285
-rw-r--r--tests/replication/_base.py264
-rw-r--r--tests/replication/tcp/streams/test_events.py2
-rw-r--r--tests/replication/test_auth.py117
-rw-r--r--tests/replication/test_client_reader_shard.py59
-rw-r--r--tests/replication/test_federation_ack.py2
-rw-r--r--tests/replication/test_federation_sender_shard.py14
-rw-r--r--tests/replication/test_multi_media_repo.py279
-rw-r--r--tests/replication/test_pusher_shard.py16
-rw-r--r--tests/replication/test_sharded_event_persister.py333
-rw-r--r--tests/rest/admin/test_admin.py69
-rw-r--r--tests/rest/admin/test_device.py158
-rw-r--r--tests/rest/admin/test_event_reports.py241
-rw-r--r--tests/rest/admin/test_media.py535
-rw-r--r--tests/rest/admin/test_room.py425
-rw-r--r--tests/rest/admin/test_statistics.py452
-rw-r--r--tests/rest/admin/test_user.py1155
-rw-r--r--tests/rest/client/test_consent.py32
-rw-r--r--tests/rest/client/test_ephemeral_message.py3
-rw-r--r--tests/rest/client/test_identity.py8
-rw-r--r--tests/rest/client/test_redactions.py10
-rw-r--r--tests/rest/client/test_retention.py3
-rw-r--r--tests/rest/client/test_shadow_banned.py26
-rw-r--r--tests/rest/client/test_third_party_rules.py185
-rw-r--r--tests/rest/client/third_party_rules.py79
-rw-r--r--tests/rest/client/v1/test_directory.py23
-rw-r--r--tests/rest/client/v1/test_events.py14
-rw-r--r--tests/rest/client/v1/test_login.py132
-rw-r--r--tests/rest/client/v1/test_presence.py21
-rw-r--r--tests/rest/client/v1/test_profile.py25
-rw-r--r--tests/rest/client/v1/test_push_rule_attrs.py99
-rw-r--r--tests/rest/client/v1/test_rooms.py355
-rw-r--r--tests/rest/client/v1/test_typing.py16
-rw-r--r--tests/rest/client/v1/utils.py195
-rw-r--r--tests/rest/client/v2_alpha/test_account.py92
-rw-r--r--tests/rest/client/v2_alpha/test_auth.py223
-rw-r--r--tests/rest/client/v2_alpha/test_capabilities.py12
-rw-r--r--tests/rest/client/v2_alpha/test_filter.py21
-rw-r--r--tests/rest/client/v2_alpha/test_password_policy.py26
-rw-r--r--tests/rest/client/v2_alpha/test_register.py115
-rw-r--r--tests/rest/client/v2_alpha/test_relations.py54
-rw-r--r--tests/rest/client/v2_alpha/test_shared_rooms.py17
-rw-r--r--tests/rest/client/v2_alpha/test_sync.py50
-rw-r--r--tests/rest/key/v2/test_remote_key_resource.py12
-rw-r--r--tests/rest/media/v1/test_media_storage.py39
-rw-r--r--tests/rest/media/v1/test_url_preview.py157
-rw-r--r--tests/rest/test_health.py11
-rw-r--r--tests/rest/test_well_known.py16
-rw-r--r--tests/server.py148
-rw-r--r--tests/server_notices/test_consent.py9
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py11
-rw-r--r--tests/state/test_v2.py193
-rw-r--r--tests/storage/test__base.py299
-rw-r--r--tests/storage/test_appservice.py78
-rw-r--r--tests/storage/test_cleanup_extrems.py36
-rw-r--r--tests/storage/test_client_ips.py19
-rw-r--r--tests/storage/test_devices.py26
-rw-r--r--tests/storage/test_e2e_room_keys.py2
-rw-r--r--tests/storage/test_event_federation.py21
-rw-r--r--tests/storage/test_event_metrics.py4
-rw-r--r--tests/storage/test_events.py334
-rw-r--r--tests/storage/test_id_generators.py25
-rw-r--r--tests/storage/test_main.py7
-rw-r--r--tests/storage/test_purge.py2
-rw-r--r--tests/storage/test_redaction.py15
-rw-r--r--tests/storage/test_registration.py10
-rw-r--r--tests/storage/test_roommember.py12
-rw-r--r--tests/storage/test_user_directory.py23
-rw-r--r--tests/test_federation.py10
-rw-r--r--tests/test_mau.py50
-rw-r--r--tests/test_metrics.py4
-rw-r--r--tests/test_phone_home.py2
-rw-r--r--tests/test_preview.py27
-rw-r--r--tests/test_server.py49
-rw-r--r--tests/test_state.py1
-rw-r--r--tests/test_terms_auth.py9
-rw-r--r--tests/test_utils/__init__.py73
-rw-r--r--tests/test_utils/event_injection.py2
-rw-r--r--tests/test_utils/logging_setup.py2
-rw-r--r--tests/unittest.py167
-rw-r--r--tests/util/caches/test_deferred_cache.py251
-rw-r--r--tests/util/caches/test_descriptors.py381
-rw-r--r--tests/util/test_lrucache.py12
-rw-r--r--tests/utils.py253
-rw-r--r--tox.ini59
534 files changed, 25926 insertions, 11691 deletions
diff --git a/.buildkite/scripts/test_old_deps.sh b/.buildkite/scripts/test_old_deps.sh

index cdb77b556c..9905c4bc4f 100755
--- a/.buildkite/scripts/test_old_deps.sh
+++ b/.buildkite/scripts/test_old_deps.sh
@@ -6,7 +6,7 @@ set -ex
 apt-get update
-apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev zlib1g-dev tox
+apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox
 export LANG="C.UTF-8"

diff --git a/.buildkite/test_db.db b/.buildkite/test_db.db
index 361369a581..a0d9f16a75 100644
--- a/.buildkite/test_db.db
+++ b/.buildkite/test_db.db
Binary files differ

diff --git a/.buildkite/worker-blacklist b/.buildkite/worker-blacklist
index fd98cbbaf6..5975cb98cf 100644
--- a/.buildkite/worker-blacklist
+++ b/.buildkite/worker-blacklist
@@ -1,41 +1,10 @@
 # This file serves as a blacklist for SyTest tests that we expect will fail in
 # Synapse when run under worker mode. For more details, see sytest-blacklist.
-Message history can be paginated
-
 Can re-join room if re-invited
-The only membership state included in an initial sync is for all the senders in the timeline
-
-Local device key changes get to remote servers
-
-If remote user leaves room we no longer receive device updates
-
-Forgotten room messages cannot be paginated
-
-Inbound federation can get public room list
-
-Members from the gap are included in gappy incr LL sync
-
-Leaves are present in non-gapped incremental syncs
-
-Old leaves are present in gapped incremental syncs
-
-User sees updates to presence from other users in the incremental sync.
-
-Gapped incremental syncs include all state changes
-
-Old members are included in gappy incr LL sync if they start speaking
-
 # new failures as of https://github.com/matrix-org/sytest/pull/732
 Device list doesn't change if remote server is down
-Remote servers cannot set power levels in rooms without existing powerlevels
-Remote servers should reject attempts by non-creators to set the power levels
 # https://buildkite.com/matrix-dot-org/synapse/builds/6134#6f67bf47-e234-474d-80e8-c6e1868b15c5
 Server correctly handles incoming m.device_list_update
-
-# this fails reliably with a torture level of 100 due to https://github.com/matrix-org/synapse/issues/6536
-Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state
-
-Can get rooms/{roomId}/members at a given point

diff --git a/.circleci/config.yml b/.circleci/config.yml
index 5bd2ab2b76..375a7f7b04 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -1,22 +1,35 @@
-version: 2
+version: 2.1
 jobs:
   dockerhubuploadrelease:
-    machine: true
+    docker:
+      - image: docker:git
     steps:
       - checkout
-      - run: docker build -f docker/Dockerfile --label gitsha1=${CIRCLE_SHA1} -t matrixdotorg/synapse:${CIRCLE_TAG} .
+      - docker_prepare
       - run: docker login --username $DOCKER_HUB_USERNAME --password $DOCKER_HUB_PASSWORD
-      - run: docker push matrixdotorg/synapse:${CIRCLE_TAG}
+      # for release builds, we want to get the amd64 image out asap, so first
+      # we do an amd64-only build, before following up with a multiarch build.
+      - docker_build:
+          tag: -t matrixdotorg/synapse:${CIRCLE_TAG}
+          platforms: linux/amd64
+      - docker_build:
+          tag: -t matrixdotorg/synapse:${CIRCLE_TAG}
+          platforms: linux/amd64,linux/arm/v7,linux/arm64
+
   dockerhubuploadlatest:
-    machine: true
+    docker:
+      - image: docker:git
     steps:
       - checkout
-      - run: docker build -f docker/Dockerfile --label gitsha1=${CIRCLE_SHA1} -t matrixdotorg/synapse:latest .
+      - docker_prepare
       - run: docker login --username $DOCKER_HUB_USERNAME --password $DOCKER_HUB_PASSWORD
-      - run: docker push matrixdotorg/synapse:latest
+      # for `latest`, we don't want the arm images to disappear, so don't update the tag
+      # until all of the platforms are built.
+      - docker_build:
+          tag: -t matrixdotorg/synapse:latest
+          platforms: linux/amd64,linux/arm/v7,linux/arm64
 
 workflows:
-  version: 2
   build:
     jobs:
       - dockerhubuploadrelease:
@@ -29,3 +42,37 @@ workflows:
           filters:
             branches:
               only: master
+
+commands:
+  docker_prepare:
+    description: Sets up a remote docker server, downloads the buildx cli plugin, and enables multiarch images
+    parameters:
+      buildx_version:
+        type: string
+        default: "v0.4.1"
+    steps:
+      - setup_remote_docker:
+          # 19.03.13 was the most recent available on circleci at the time of
+          # writing.
+          version: 19.03.13
+      - run: apk add --no-cache curl
+      - run: mkdir -vp ~/.docker/cli-plugins/ ~/dockercache
+      - run: curl --silent -L "https://github.com/docker/buildx/releases/download/<< parameters.buildx_version >>/buildx-<< parameters.buildx_version >>.linux-amd64" > ~/.docker/cli-plugins/docker-buildx
+      - run: chmod a+x ~/.docker/cli-plugins/docker-buildx
+      # install qemu links in /proc/sys/fs/binfmt_misc on the docker instance running the circleci job
+      - run: docker run --rm --privileged multiarch/qemu-user-static --reset -p yes
+      # create a context named `builder` for the builds
+      - run: docker context create builder
+      # create a buildx builder using the new context, and set it as the default
+      - run: docker buildx create builder --use
+
+  docker_build:
+    description: Builds and pushed images to dockerhub using buildx
+    parameters:
+      platforms:
+        type: string
+        default: linux/amd64
+      tag:
+        type: string
+    steps:
+      - run: docker buildx build -f docker/Dockerfile --push --platform << parameters.platforms >> --label gitsha1=${CIRCLE_SHA1} << parameters.tag >> --progress=plain .

diff --git a/.gitignore b/.gitignore
index af36c00cfa..9bb5bdd647 100644
--- a/.gitignore
+++ b/.gitignore
@@ -21,6 +21,7 @@ _trial_temp*/
 /.python-version
 /*.signing.key
 /env/
+/.venv*/
 /homeserver*.yaml
 /logs
 /media_store/

diff --git a/CHANGES.md b/CHANGES.md
index 75dc5fa893..db11de0e85 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,571 @@
+Synapse 1.25.0 (2021-01-13)
+===========================
+
+Ending Support for Python 3.5 and Postgres 9.5
+----------------------------------------------
+
+With this release, the Synapse team is announcing a formal deprecation policy for our platform dependencies, like Python and PostgreSQL:
+
+All future releases of Synapse will follow the upstream end-of-life schedules.
+
+Which means:
+
+* This is the last release which guarantees support for Python 3.5.
+* We will end support for PostgreSQL 9.5 early next month.
+* We will end support for Python 3.6 and PostgreSQL 9.6 near the end of the year.
+
+Crucially, this means __we will not produce .deb packages for Debian 9 (Stretch) or Ubuntu 16.04 (Xenial)__ beyond the transition period described below.
+
+The website https://endoflife.date/ has convenient summaries of the support schedules for projects like [Python](https://endoflife.date/python) and [PostgreSQL](https://endoflife.date/postgresql).
+
+If you are unable to upgrade your environment to a supported version of Python or Postgres, we encourage you to consider using the [Synapse Docker images](./INSTALL.md#docker-images-and-ansible-playbooks) instead.
+
+### Transition Period
+
+We will make a good faith attempt to avoid breaking compatibility in all releases through the end of March 2021. However, critical security vulnerabilities in dependencies or other unanticipated circumstances may arise which necessitate breaking compatibility earlier.
+
+We intend to continue producing .deb packages for Debian 9 (Stretch) and Ubuntu 16.04 (Xenial) through the transition period.
+
+Removal warning
+---------------
+
+The old [Purge Room API](https://github.com/matrix-org/synapse/tree/master/docs/admin_api/purge_room.md)
+and [Shutdown Room API](https://github.com/matrix-org/synapse/tree/master/docs/admin_api/shutdown_room.md)
+are deprecated and will be removed in a future release. They will be replaced by the
+[Delete Room API](https://github.com/matrix-org/synapse/tree/master/docs/admin_api/rooms.md#delete-room-api).
+
+`POST /_synapse/admin/v1/rooms/<room_id>/delete` replaces `POST /_synapse/admin/v1/purge_room` and
+`POST /_synapse/admin/v1/shutdown_room/<room_id>`.
+
+Bugfixes
+--------
+
+- Fix HTTP proxy support when using a proxy that is on a blacklisted IP. Introduced in v1.25.0rc1. Contributed by @Bubu. ([\#9084](https://github.com/matrix-org/synapse/issues/9084))
+
+
+Synapse 1.25.0rc1 (2021-01-06)
+==============================
+
+Features
+--------
+
+- Add an admin API that lets server admins get power in rooms in which local users have power. ([\#8756](https://github.com/matrix-org/synapse/issues/8756))
+- Add optional HTTP authentication to replication endpoints. ([\#8853](https://github.com/matrix-org/synapse/issues/8853))
+- Improve the error messages printed as a result of configuration problems for extension modules. ([\#8874](https://github.com/matrix-org/synapse/issues/8874))
+- Add the number of local devices to Room Details Admin API. Contributed by @dklimpel. ([\#8886](https://github.com/matrix-org/synapse/issues/8886))
+- Add `X-Robots-Tag` header to stop web crawlers from indexing media. Contributed by Aaron Raimist. ([\#8887](https://github.com/matrix-org/synapse/issues/8887))
+- Spam-checkers may now define their methods as `async`. ([\#8890](https://github.com/matrix-org/synapse/issues/8890))
+- Add support for allowing users to pick their own user ID during a single-sign-on login.
([\#8897](https://github.com/matrix-org/synapse/issues/8897), [\#8900](https://github.com/matrix-org/synapse/issues/8900), [\#8911](https://github.com/matrix-org/synapse/issues/8911), [\#8938](https://github.com/matrix-org/synapse/issues/8938), [\#8941](https://github.com/matrix-org/synapse/issues/8941), [\#8942](https://github.com/matrix-org/synapse/issues/8942), [\#8951](https://github.com/matrix-org/synapse/issues/8951)) +- Add an `email.invite_client_location` configuration option to send a web client location to the invite endpoint on the identity server which allows customisation of the email template. ([\#8930](https://github.com/matrix-org/synapse/issues/8930)) +- The search term in the list room and list user Admin APIs is now treated as case-insensitive. ([\#8931](https://github.com/matrix-org/synapse/issues/8931)) +- Apply an IP range blacklist to push and key revocation requests. ([\#8821](https://github.com/matrix-org/synapse/issues/8821), [\#8870](https://github.com/matrix-org/synapse/issues/8870), [\#8954](https://github.com/matrix-org/synapse/issues/8954)) +- Add an option to allow re-use of user-interactive authentication sessions for a period of time. ([\#8970](https://github.com/matrix-org/synapse/issues/8970)) +- Allow running the redact endpoint on workers. ([\#8994](https://github.com/matrix-org/synapse/issues/8994)) + + +Bugfixes +-------- + +- Fix bug where we might not correctly calculate the current state for rooms with multiple extremities. ([\#8827](https://github.com/matrix-org/synapse/issues/8827)) +- Fix a long-standing bug in the register admin endpoint (`/_synapse/admin/v1/register`) when the `mac` field was not provided. The endpoint now properly returns a 400 error. Contributed by @edwargix. ([\#8837](https://github.com/matrix-org/synapse/issues/8837)) +- Fix a long-standing bug on Synapse instances supporting Single-Sign-On, where users would be prompted to enter their password to confirm certain actions, even though they have not set a password. ([\#8858](https://github.com/matrix-org/synapse/issues/8858)) +- Fix a longstanding bug where a 500 error would be returned if the `Content-Length` header was not provided to the upload media resource. ([\#8862](https://github.com/matrix-org/synapse/issues/8862)) +- Add additional validation to pusher URLs to be compliant with the specification. ([\#8865](https://github.com/matrix-org/synapse/issues/8865)) +- Fix the error code that is returned when a user tries to register on a homeserver on which new-user registration has been disabled. ([\#8867](https://github.com/matrix-org/synapse/issues/8867)) +- Fix a bug where `PUT /_synapse/admin/v2/users/<user_id>` failed to create a new user when `avatar_url` is specified. Bug introduced in Synapse v1.9.0. ([\#8872](https://github.com/matrix-org/synapse/issues/8872)) +- Fix a 500 error when attempting to preview an empty HTML file. ([\#8883](https://github.com/matrix-org/synapse/issues/8883)) +- Fix occasional deadlock when handling SIGHUP. ([\#8918](https://github.com/matrix-org/synapse/issues/8918)) +- Fix login API to not ratelimit application services that have ratelimiting disabled. ([\#8920](https://github.com/matrix-org/synapse/issues/8920)) +- Fix bug where we ratelimited auto joining of rooms on registration (using `auto_join_rooms` config). ([\#8921](https://github.com/matrix-org/synapse/issues/8921)) +- Fix a bug where deactivated users appeared in the user directory when their profile information was updated. 
([\#8933](https://github.com/matrix-org/synapse/issues/8933), [\#8964](https://github.com/matrix-org/synapse/issues/8964)) +- Fix bug introduced in Synapse v1.24.0 which would cause an exception on startup if both `enabled` and `localdb_enabled` were set to `False` in the `password_config` setting of the configuration file. ([\#8937](https://github.com/matrix-org/synapse/issues/8937)) +- Fix a bug where 500 errors would be returned if the `m.room_history_visibility` event had invalid content. ([\#8945](https://github.com/matrix-org/synapse/issues/8945)) +- Fix a bug causing common English words to not be considered for a user directory search. ([\#8959](https://github.com/matrix-org/synapse/issues/8959)) +- Fix bug where application services couldn't register new ghost users if the server had reached its MAU limit. ([\#8962](https://github.com/matrix-org/synapse/issues/8962)) +- Fix a long-standing bug where a `m.image` event without a `url` would cause errors on push. ([\#8965](https://github.com/matrix-org/synapse/issues/8965)) +- Fix a small bug in v2 state resolution algorithm, which could also cause performance issues for rooms with large numbers of power levels. ([\#8971](https://github.com/matrix-org/synapse/issues/8971)) +- Add validation to the `sendToDevice` API to raise a missing parameters error instead of a 500 error. ([\#8975](https://github.com/matrix-org/synapse/issues/8975)) +- Add validation of group IDs to raise a 400 error instead of a 500 eror. ([\#8977](https://github.com/matrix-org/synapse/issues/8977)) + + +Improved Documentation +---------------------- + +- Fix the "Event persist rate" section of the included grafana dashboard by adding missing prometheus rules. ([\#8802](https://github.com/matrix-org/synapse/issues/8802)) +- Combine related media admin API docs. ([\#8839](https://github.com/matrix-org/synapse/issues/8839)) +- Fix an error in the documentation for the SAML username mapping provider. ([\#8873](https://github.com/matrix-org/synapse/issues/8873)) +- Clarify comments around template directories in `sample_config.yaml`. ([\#8891](https://github.com/matrix-org/synapse/issues/8891)) +- Move instructions for database setup, adjusted heading levels and improved syntax highlighting in [INSTALL.md](../INSTALL.md). Contributed by @fossterer. ([\#8987](https://github.com/matrix-org/synapse/issues/8987)) +- Update the example value of `group_creation_prefix` in the sample configuration. ([\#8992](https://github.com/matrix-org/synapse/issues/8992)) +- Link the Synapse developer room to the development section in the docs. ([\#9002](https://github.com/matrix-org/synapse/issues/9002)) + + +Deprecations and Removals +------------------------- + +- Deprecate Shutdown Room and Purge Room Admin APIs. ([\#8829](https://github.com/matrix-org/synapse/issues/8829)) + + +Internal Changes +---------------- + +- Properly store the mapping of external ID to Matrix ID for CAS users. ([\#8856](https://github.com/matrix-org/synapse/issues/8856), [\#8958](https://github.com/matrix-org/synapse/issues/8958)) +- Remove some unnecessary stubbing from unit tests. ([\#8861](https://github.com/matrix-org/synapse/issues/8861)) +- Remove unused `FakeResponse` class from unit tests. ([\#8864](https://github.com/matrix-org/synapse/issues/8864)) +- Pass `room_id` to `get_auth_chain_difference`. ([\#8879](https://github.com/matrix-org/synapse/issues/8879)) +- Add type hints to push module. 
([\#8880](https://github.com/matrix-org/synapse/issues/8880), [\#8882](https://github.com/matrix-org/synapse/issues/8882), [\#8901](https://github.com/matrix-org/synapse/issues/8901), [\#8940](https://github.com/matrix-org/synapse/issues/8940), [\#8943](https://github.com/matrix-org/synapse/issues/8943), [\#9020](https://github.com/matrix-org/synapse/issues/9020)) +- Simplify logic for handling user-interactive-auth via single-sign-on servers. ([\#8881](https://github.com/matrix-org/synapse/issues/8881)) +- Skip the SAML tests if the requirements (`pysaml2` and `xmlsec1`) aren't available. ([\#8905](https://github.com/matrix-org/synapse/issues/8905)) +- Fix multiarch docker image builds. ([\#8906](https://github.com/matrix-org/synapse/issues/8906)) +- Don't publish `latest` docker image until all archs are built. ([\#8909](https://github.com/matrix-org/synapse/issues/8909)) +- Various clean-ups to the structured logging and logging context code. ([\#8916](https://github.com/matrix-org/synapse/issues/8916), [\#8935](https://github.com/matrix-org/synapse/issues/8935)) +- Automatically drop stale forward-extremities under some specific conditions. ([\#8929](https://github.com/matrix-org/synapse/issues/8929)) +- Refactor test utilities for injecting HTTP requests. ([\#8946](https://github.com/matrix-org/synapse/issues/8946)) +- Add a maximum size of 50 kilobytes to .well-known lookups. ([\#8950](https://github.com/matrix-org/synapse/issues/8950)) +- Fix bug in `generate_log_config` script which made it write empty files. ([\#8952](https://github.com/matrix-org/synapse/issues/8952)) +- Clean up tox.ini file; disable coverage checking for non-test runs. ([\#8963](https://github.com/matrix-org/synapse/issues/8963)) +- Add type hints to the admin and room list handlers. ([\#8973](https://github.com/matrix-org/synapse/issues/8973)) +- Add type hints to the receipts and user directory handlers. ([\#8976](https://github.com/matrix-org/synapse/issues/8976)) +- Drop the unused `local_invites` table. ([\#8979](https://github.com/matrix-org/synapse/issues/8979)) +- Add type hints to the base storage code. ([\#8980](https://github.com/matrix-org/synapse/issues/8980)) +- Support using PyJWT v2.0.0 in the test suite. ([\#8986](https://github.com/matrix-org/synapse/issues/8986)) +- Fix `tests.federation.transport.RoomDirectoryFederationTests` and ensure it runs in CI. ([\#8998](https://github.com/matrix-org/synapse/issues/8998)) +- Add type hints to the crypto module. ([\#8999](https://github.com/matrix-org/synapse/issues/8999)) + + +Synapse 1.24.0 (2020-12-09) +=========================== + +Due to the two security issues highlighted below, server administrators are +encouraged to update Synapse. We are not aware of these vulnerabilities being +exploited in the wild. + +Security advisory +----------------- + +The following issues are fixed in v1.23.1 and v1.24.0. + +- There is a denial of service attack + ([CVE-2020-26257](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26257)) + against the federation APIs in which future events will not be correctly sent + to other servers over federation. This affects all servers that participate in + open federation. (Fixed in [#8776](https://github.com/matrix-org/synapse/pull/8776)). + +- Synapse may be affected by OpenSSL + [CVE-2020-1971](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-1971). + Synapse administrators should ensure that they have the latest versions of + the cryptography Python package installed. 
+ +To upgrade Synapse along with the cryptography package: + +* Administrators using the [`matrix.org` Docker + image](https://hub.docker.com/r/matrixdotorg/synapse/) or the [Debian/Ubuntu + packages from + `matrix.org`](https://github.com/matrix-org/synapse/blob/master/INSTALL.md#matrixorg-packages) + should ensure that they have version 1.24.0 or 1.23.1 installed: these images include + the updated packages. +* Administrators who have [installed Synapse from + source](https://github.com/matrix-org/synapse/blob/master/INSTALL.md#installing-from-source) + should upgrade the cryptography package within their virtualenv by running: + ```sh + <path_to_virtualenv>/bin/pip install 'cryptography>=3.3' + ``` +* Administrators who have installed Synapse from distribution packages should + consult the information from their distributions. + +Internal Changes +---------------- + +- Add a maximum version for pysaml2 on Python 3.5. ([\#8898](https://github.com/matrix-org/synapse/issues/8898)) + + +Synapse 1.23.1 (2020-12-09) +=========================== + +Due to the two security issues highlighted below, server administrators are +encouraged to update Synapse. We are not aware of these vulnerabilities being +exploited in the wild. + +Security advisory +----------------- + +The following issues are fixed in v1.23.1 and v1.24.0. + +- There is a denial of service attack + ([CVE-2020-26257](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26257)) + against the federation APIs in which future events will not be correctly sent + to other servers over federation. This affects all servers that participate in + open federation. (Fixed in [#8776](https://github.com/matrix-org/synapse/pull/8776)). + +- Synapse may be affected by OpenSSL + [CVE-2020-1971](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-1971). + Synapse administrators should ensure that they have the latest versions of + the cryptography Python package installed. + +To upgrade Synapse along with the cryptography package: + +* Administrators using the [`matrix.org` Docker + image](https://hub.docker.com/r/matrixdotorg/synapse/) or the [Debian/Ubuntu + packages from + `matrix.org`](https://github.com/matrix-org/synapse/blob/master/INSTALL.md#matrixorg-packages) + should ensure that they have version 1.24.0 or 1.23.1 installed: these images include + the updated packages. +* Administrators who have [installed Synapse from + source](https://github.com/matrix-org/synapse/blob/master/INSTALL.md#installing-from-source) + should upgrade the cryptography package within their virtualenv by running: + ```sh + <path_to_virtualenv>/bin/pip install 'cryptography>=3.3' + ``` +* Administrators who have installed Synapse from distribution packages should + consult the information from their distributions. + +Bugfixes +-------- + +- Fix a bug in some federation APIs which could lead to unexpected behaviour if different parameters were set in the URI and the request body. ([\#8776](https://github.com/matrix-org/synapse/issues/8776)) + + +Internal Changes +---------------- + +- Add a maximum version for pysaml2 on Python 3.5. ([\#8898](https://github.com/matrix-org/synapse/issues/8898)) + + +Synapse 1.24.0rc2 (2020-12-04) +============================== + +Bugfixes +-------- + +- Fix a regression in v1.24.0rc1 which failed to allow SAML mapping providers which were unable to redirect users to an additional page. 
([\#8878](https://github.com/matrix-org/synapse/issues/8878)) + + +Internal Changes +---------------- + +- Add support for the `prometheus_client` newer than 0.9.0. Contributed by Jordan Bancino. ([\#8875](https://github.com/matrix-org/synapse/issues/8875)) + + +Synapse 1.24.0rc1 (2020-12-02) +============================== + +Features +-------- + +- Add admin API for logging in as a user. ([\#8617](https://github.com/matrix-org/synapse/issues/8617)) +- Allow specification of the SAML IdP if the metadata returns multiple IdPs. ([\#8630](https://github.com/matrix-org/synapse/issues/8630)) +- Add support for re-trying generation of a localpart for OpenID Connect mapping providers. ([\#8801](https://github.com/matrix-org/synapse/issues/8801), [\#8855](https://github.com/matrix-org/synapse/issues/8855)) +- Allow the `Date` header through CORS. Contributed by Nicolas Chamo. ([\#8804](https://github.com/matrix-org/synapse/issues/8804)) +- Add a config option, `push.group_by_unread_count`, which controls whether unread message counts in push notifications are defined as "the number of rooms with unread messages" or "total unread messages". ([\#8820](https://github.com/matrix-org/synapse/issues/8820)) +- Add `force_purge` option to delete-room admin api. ([\#8843](https://github.com/matrix-org/synapse/issues/8843)) + + +Bugfixes +-------- + +- Fix a bug where appservices may be sent an excessive amount of read receipts and presence. Broke in v1.22.0. ([\#8744](https://github.com/matrix-org/synapse/issues/8744)) +- Fix a bug in some federation APIs which could lead to unexpected behaviour if different parameters were set in the URI and the request body. ([\#8776](https://github.com/matrix-org/synapse/issues/8776)) +- Fix a bug where synctl could spawn duplicate copies of a worker. Contributed by Waylon Cude. ([\#8798](https://github.com/matrix-org/synapse/issues/8798)) +- Allow per-room profiles to be used for the server notice user. ([\#8799](https://github.com/matrix-org/synapse/issues/8799)) +- Fix a bug where logging could break after a call to SIGHUP. ([\#8817](https://github.com/matrix-org/synapse/issues/8817)) +- Fix `register_new_matrix_user` failing with "Bad Request" when trailing slash is included in server URL. Contributed by @angdraug. ([\#8823](https://github.com/matrix-org/synapse/issues/8823)) +- Fix a minor long-standing bug in login, where we would offer the `password` login type if a custom auth provider supported it, even if password login was disabled. ([\#8835](https://github.com/matrix-org/synapse/issues/8835)) +- Fix a long-standing bug which caused Synapse to require unspecified parameters during user-interactive authentication. ([\#8848](https://github.com/matrix-org/synapse/issues/8848)) +- Fix a bug introduced in v1.20.0 where the user-agent and IP address reported during user registration for CAS, OpenID Connect, and SAML were of the wrong form. ([\#8784](https://github.com/matrix-org/synapse/issues/8784)) + + +Improved Documentation +---------------------- + +- Clarify the usecase for a msisdn delegate. Contributed by Adrian Wannenmacher. ([\#8734](https://github.com/matrix-org/synapse/issues/8734)) +- Remove extraneous comma from JSON example in User Admin API docs. ([\#8771](https://github.com/matrix-org/synapse/issues/8771)) +- Update `turn-howto.md` with troubleshooting notes. ([\#8779](https://github.com/matrix-org/synapse/issues/8779)) +- Fix the example on how to set the `Content-Type` header in nginx for the Client Well-Known URI. 
([\#8793](https://github.com/matrix-org/synapse/issues/8793)) +- Improve the documentation for the admin API to list all media in a room with respect to encrypted events. ([\#8795](https://github.com/matrix-org/synapse/issues/8795)) +- Update the formatting of the `push` section of the homeserver config file to better align with the [code style guidelines](https://github.com/matrix-org/synapse/blob/develop/docs/code_style.md#configuration-file-format). ([\#8818](https://github.com/matrix-org/synapse/issues/8818)) +- Improve documentation how to configure prometheus for workers. ([\#8822](https://github.com/matrix-org/synapse/issues/8822)) +- Update example prometheus console. ([\#8824](https://github.com/matrix-org/synapse/issues/8824)) + + +Deprecations and Removals +------------------------- + +- Remove old `/_matrix/client/*/admin` endpoints which were deprecated since Synapse 1.20.0. ([\#8785](https://github.com/matrix-org/synapse/issues/8785)) +- Disable pretty printing JSON responses for curl. Users who want pretty-printed output should use [jq](https://stedolan.github.io/jq/) in combination with curl. Contributed by @tulir. ([\#8833](https://github.com/matrix-org/synapse/issues/8833)) + + +Internal Changes +---------------- + +- Simplify the way the `HomeServer` object caches its internal attributes. ([\#8565](https://github.com/matrix-org/synapse/issues/8565), [\#8851](https://github.com/matrix-org/synapse/issues/8851)) +- Add an example and documentation for clock skew to the SAML2 sample configuration to allow for clock/time difference between the homserver and IdP. Contributed by @localguru. ([\#8731](https://github.com/matrix-org/synapse/issues/8731)) +- Generalise `RoomMemberHandler._locally_reject_invite` to apply to more flows than just invite. ([\#8751](https://github.com/matrix-org/synapse/issues/8751)) +- Generalise `RoomStore.maybe_store_room_on_invite` to handle other, non-invite membership events. ([\#8754](https://github.com/matrix-org/synapse/issues/8754)) +- Refactor test utilities for injecting HTTP requests. ([\#8757](https://github.com/matrix-org/synapse/issues/8757), [\#8758](https://github.com/matrix-org/synapse/issues/8758), [\#8759](https://github.com/matrix-org/synapse/issues/8759), [\#8760](https://github.com/matrix-org/synapse/issues/8760), [\#8761](https://github.com/matrix-org/synapse/issues/8761), [\#8777](https://github.com/matrix-org/synapse/issues/8777)) +- Consolidate logic between the OpenID Connect and SAML code. ([\#8765](https://github.com/matrix-org/synapse/issues/8765)) +- Use `TYPE_CHECKING` instead of magic `MYPY` variable. ([\#8770](https://github.com/matrix-org/synapse/issues/8770)) +- Add a commandline script to sign arbitrary json objects. ([\#8772](https://github.com/matrix-org/synapse/issues/8772)) +- Minor log line improvements for the SSO mapping code used to generate Matrix IDs from SSO IDs. ([\#8773](https://github.com/matrix-org/synapse/issues/8773)) +- Add additional error checking for OpenID Connect and SAML mapping providers. ([\#8774](https://github.com/matrix-org/synapse/issues/8774), [\#8800](https://github.com/matrix-org/synapse/issues/8800)) +- Add type hints to HTTP abstractions. ([\#8806](https://github.com/matrix-org/synapse/issues/8806), [\#8812](https://github.com/matrix-org/synapse/issues/8812)) +- Remove unnecessary function arguments and add typing to several membership replication classes. 
([\#8809](https://github.com/matrix-org/synapse/issues/8809)) +- Optimise the lookup for an invite from another homeserver when trying to reject it. ([\#8815](https://github.com/matrix-org/synapse/issues/8815)) +- Add tests for `password_auth_provider`s. ([\#8819](https://github.com/matrix-org/synapse/issues/8819)) +- Drop redundant database index on `event_json`. ([\#8845](https://github.com/matrix-org/synapse/issues/8845)) +- Simplify `uk.half-shot.msc2778.login.application_service` login handler. ([\#8847](https://github.com/matrix-org/synapse/issues/8847)) +- Refactor `password_auth_provider` support code. ([\#8849](https://github.com/matrix-org/synapse/issues/8849)) +- Add missing `ordering` to background database updates. ([\#8850](https://github.com/matrix-org/synapse/issues/8850)) +- Allow for specifying a room version when creating a room in unit tests via `RestHelper.create_room_as`. ([\#8854](https://github.com/matrix-org/synapse/issues/8854)) + + +Synapse 1.23.0 (2020-11-18) +=========================== + +This release changes the way structured logging is configured. See the [upgrade notes](UPGRADE.rst#upgrading-to-v1230) for details. + +**Note**: We are aware of a trivially exploitable denial of service vulnerability in versions of Synapse prior to 1.20.0. Complete details will be disclosed on Monday, November 23rd. If you have not upgraded recently, please do so. + +Bugfixes +-------- + +- Fix a dependency versioning bug in the Dockerfile that prevented Synapse from starting. ([\#8767](https://github.com/matrix-org/synapse/issues/8767)) + + +Synapse 1.23.0rc1 (2020-11-13) +============================== + +Features +-------- + +- Add a push rule that highlights when a jitsi conference is created in a room. ([\#8286](https://github.com/matrix-org/synapse/issues/8286)) +- Add an admin api to delete a single file or files that were not used for a defined time from server. Contributed by @dklimpel. ([\#8519](https://github.com/matrix-org/synapse/issues/8519)) +- Split admin API for reported events (`GET /_synapse/admin/v1/event_reports`) into detail and list endpoints. This is a breaking change to #8217 which was introduced in Synapse v1.21.0. Those who already use this API should check their scripts. Contributed by @dklimpel. ([\#8539](https://github.com/matrix-org/synapse/issues/8539)) +- Support generating structured logs via the standard logging configuration. ([\#8607](https://github.com/matrix-org/synapse/issues/8607), [\#8685](https://github.com/matrix-org/synapse/issues/8685)) +- Add an admin API to allow server admins to list users' pushers. Contributed by @dklimpel. ([\#8610](https://github.com/matrix-org/synapse/issues/8610), [\#8689](https://github.com/matrix-org/synapse/issues/8689)) +- Add an admin API `GET /_synapse/admin/v1/users/<user_id>/media` to get information about uploaded media. Contributed by @dklimpel. ([\#8647](https://github.com/matrix-org/synapse/issues/8647)) +- Add an admin API for local user media statistics. Contributed by @dklimpel. ([\#8700](https://github.com/matrix-org/synapse/issues/8700)) +- Add `displayname` to Shared-Secret Registration for admins. ([\#8722](https://github.com/matrix-org/synapse/issues/8722)) + + +Bugfixes +-------- + +- Fix fetching of E2E cross signing keys over federation when only one of the master key and device signing key is cached already. 
([\#8455](https://github.com/matrix-org/synapse/issues/8455)) +- Fix a bug where Synapse would blindly forward bad responses from federation to clients when retrieving profile information. ([\#8580](https://github.com/matrix-org/synapse/issues/8580)) +- Fix a bug where the account validity endpoint would silently fail if the user ID did not have an expiration time. It now returns a 400 error. ([\#8620](https://github.com/matrix-org/synapse/issues/8620)) +- Fix email notifications for invites without local state. ([\#8627](https://github.com/matrix-org/synapse/issues/8627)) +- Fix handling of invalid group IDs to return a 400 rather than log an exception and return a 500. ([\#8628](https://github.com/matrix-org/synapse/issues/8628)) +- Fix handling of User-Agent headers that are invalid UTF-8, which caused user agents of users to not get correctly recorded. ([\#8632](https://github.com/matrix-org/synapse/issues/8632)) +- Fix a bug in the `joined_rooms` admin API if the user has never joined any rooms. The bug was introduced, along with the API, in v1.21.0. ([\#8643](https://github.com/matrix-org/synapse/issues/8643)) +- Fix exception during handling multiple concurrent requests for remote media when using multiple media repositories. ([\#8682](https://github.com/matrix-org/synapse/issues/8682)) +- Fix bug that prevented Synapse from recovering after losing connection to the database. ([\#8726](https://github.com/matrix-org/synapse/issues/8726)) +- Fix bug where the `/_synapse/admin/v1/send_server_notice` API could send notices to non-notice rooms. ([\#8728](https://github.com/matrix-org/synapse/issues/8728)) +- Fix PostgreSQL port script fails when DB has no backfilled events. Broke in v1.21.0. ([\#8729](https://github.com/matrix-org/synapse/issues/8729)) +- Fix PostgreSQL port script to correctly handle foreign key constraints. Broke in v1.21.0. ([\#8730](https://github.com/matrix-org/synapse/issues/8730)) +- Fix PostgreSQL port script so that it can be run again after a failure. Broke in v1.21.0. ([\#8755](https://github.com/matrix-org/synapse/issues/8755)) + + +Improved Documentation +---------------------- + +- Instructions for Azure AD in the OpenID Connect documentation. Contributed by peterk. ([\#8582](https://github.com/matrix-org/synapse/issues/8582)) +- Improve the sample configuration for single sign-on providers. ([\#8635](https://github.com/matrix-org/synapse/issues/8635)) +- Fix the filepath of Dex's example config and the link to Dex's Getting Started guide in the OpenID Connect docs. ([\#8657](https://github.com/matrix-org/synapse/issues/8657)) +- Note support for Python 3.9. ([\#8665](https://github.com/matrix-org/synapse/issues/8665)) +- Minor updates to docs on running tests. ([\#8666](https://github.com/matrix-org/synapse/issues/8666)) +- Interlink prometheus/grafana documentation. ([\#8667](https://github.com/matrix-org/synapse/issues/8667)) +- Notes on SSO logins and media_repository worker. ([\#8701](https://github.com/matrix-org/synapse/issues/8701)) +- Document experimental support for running multiple event persisters. ([\#8706](https://github.com/matrix-org/synapse/issues/8706)) +- Add information regarding the various sources of, and expected contributions to, Synapse's documentation to `CONTRIBUTING.md`. ([\#8714](https://github.com/matrix-org/synapse/issues/8714)) +- Migrate documentation `docs/admin_api/event_reports` to markdown. ([\#8742](https://github.com/matrix-org/synapse/issues/8742)) +- Add some helpful hints to the README for new Synapse developers. 
Contributed by @chagai95. ([\#8746](https://github.com/matrix-org/synapse/issues/8746)) + + +Internal Changes +---------------- + +- Optimise `/createRoom` with multiple invited users. ([\#8559](https://github.com/matrix-org/synapse/issues/8559)) +- Implement and use an `@lru_cache` decorator. ([\#8595](https://github.com/matrix-org/synapse/issues/8595)) +- Don't instansiate Requester directly. ([\#8614](https://github.com/matrix-org/synapse/issues/8614)) +- Type hints for `RegistrationStore`. ([\#8615](https://github.com/matrix-org/synapse/issues/8615)) +- Change schema to support access tokens belonging to one user but granting access to another. ([\#8616](https://github.com/matrix-org/synapse/issues/8616)) +- Remove unused OPTIONS handlers. ([\#8621](https://github.com/matrix-org/synapse/issues/8621)) +- Run `mypy` as part of the lint.sh script. ([\#8633](https://github.com/matrix-org/synapse/issues/8633)) +- Correct Synapse's PyPI package name in the OpenID Connect installation instructions. ([\#8634](https://github.com/matrix-org/synapse/issues/8634)) +- Catch exceptions during initialization of `password_providers`. Contributed by Nicolai Søborg. ([\#8636](https://github.com/matrix-org/synapse/issues/8636)) +- Fix typos and spelling errors in the code. ([\#8639](https://github.com/matrix-org/synapse/issues/8639)) +- Reduce number of OpenTracing spans started. ([\#8640](https://github.com/matrix-org/synapse/issues/8640), [\#8668](https://github.com/matrix-org/synapse/issues/8668), [\#8670](https://github.com/matrix-org/synapse/issues/8670)) +- Add field `total` to device list in admin API. ([\#8644](https://github.com/matrix-org/synapse/issues/8644)) +- Add more type hints to the application services code. ([\#8655](https://github.com/matrix-org/synapse/issues/8655), [\#8693](https://github.com/matrix-org/synapse/issues/8693)) +- Tell Black to format code for Python 3.5. ([\#8664](https://github.com/matrix-org/synapse/issues/8664)) +- Don't pull event from DB when handling replication traffic. ([\#8669](https://github.com/matrix-org/synapse/issues/8669)) +- Abstract some invite-related code in preparation for landing knocking. ([\#8671](https://github.com/matrix-org/synapse/issues/8671), [\#8688](https://github.com/matrix-org/synapse/issues/8688)) +- Clarify representation of events in logfiles. ([\#8679](https://github.com/matrix-org/synapse/issues/8679)) +- Don't require `hiredis` package to be installed to run unit tests. ([\#8680](https://github.com/matrix-org/synapse/issues/8680)) +- Fix typing info on cache call signature to accept `on_invalidate`. ([\#8684](https://github.com/matrix-org/synapse/issues/8684)) +- Fail tests if they do not await coroutines. ([\#8690](https://github.com/matrix-org/synapse/issues/8690)) +- Improve start time by adding an index to `e2e_cross_signing_keys.stream_id`. ([\#8694](https://github.com/matrix-org/synapse/issues/8694)) +- Re-organize the structured logging code to separate the TCP transport handling from the JSON formatting. ([\#8697](https://github.com/matrix-org/synapse/issues/8697)) +- Use Python 3.8 in Docker images by default. ([\#8698](https://github.com/matrix-org/synapse/issues/8698)) +- Remove the "draft" status of the Room Details Admin API. ([\#8702](https://github.com/matrix-org/synapse/issues/8702)) +- Improve the error returned when a non-string displayname or avatar_url is used when updating a user's profile. 
([\#8705](https://github.com/matrix-org/synapse/issues/8705)) +- Block attempts by clients to send server ACLs, or redactions of server ACLs, that would result in the local server being blocked from the room. ([\#8708](https://github.com/matrix-org/synapse/issues/8708)) +- Add metrics the allow the local sysadmin to track 3PID `/requestToken` requests. ([\#8712](https://github.com/matrix-org/synapse/issues/8712)) +- Consolidate duplicated lists of purged tables that are checked in tests. ([\#8713](https://github.com/matrix-org/synapse/issues/8713)) +- Add some `mdui:UIInfo` element examples for `saml2_config` in the homeserver config. ([\#8718](https://github.com/matrix-org/synapse/issues/8718)) +- Improve the error message returned when a remote server incorrectly sets the `Content-Type` header in response to a JSON request. ([\#8719](https://github.com/matrix-org/synapse/issues/8719)) +- Speed up repeated state resolutions on the same room by caching event ID to auth event ID lookups. ([\#8752](https://github.com/matrix-org/synapse/issues/8752)) + + +Synapse 1.22.1 (2020-10-30) +=========================== + +Bugfixes +-------- + +- Fix a bug where an appservice may not be forwarded events for a room it was recently invited to. Broke in v1.22.0. ([\#8676](https://github.com/matrix-org/synapse/issues/8676)) +- Fix `Object of type frozendict is not JSON serializable` exceptions when using third-party event rules. Broke in v1.22.0. ([\#8678](https://github.com/matrix-org/synapse/issues/8678)) + + +Synapse 1.22.0 (2020-10-27) +=========================== + +No significant changes. + + +Synapse 1.22.0rc2 (2020-10-26) +============================== + +Bugfixes +-------- + +- Fix bugs where ephemeral events were not sent to appservices. Broke in v1.22.0rc1. ([\#8648](https://github.com/matrix-org/synapse/issues/8648), [\#8656](https://github.com/matrix-org/synapse/issues/8656)) +- Fix `user_daily_visits` table to not have duplicate rows per user/device due to multiple user agents. Broke in v1.22.0rc1. ([\#8654](https://github.com/matrix-org/synapse/issues/8654)) + +Synapse 1.22.0rc1 (2020-10-22) +============================== + +Features +-------- + +- Add a configuration option for always using the "userinfo endpoint" for OpenID Connect. This fixes support for some identity providers, e.g. GitLab. Contributed by Benjamin Koch. ([\#7658](https://github.com/matrix-org/synapse/issues/7658)) +- Add ability for `ThirdPartyEventRules` modules to query and manipulate whether a room is in the public rooms directory. ([\#8292](https://github.com/matrix-org/synapse/issues/8292), [\#8467](https://github.com/matrix-org/synapse/issues/8467)) +- Add support for olm fallback keys ([MSC2732](https://github.com/matrix-org/matrix-doc/pull/2732)). ([\#8312](https://github.com/matrix-org/synapse/issues/8312), [\#8501](https://github.com/matrix-org/synapse/issues/8501)) +- Add support for running background tasks in a separate worker process. ([\#8369](https://github.com/matrix-org/synapse/issues/8369), [\#8458](https://github.com/matrix-org/synapse/issues/8458), [\#8489](https://github.com/matrix-org/synapse/issues/8489), [\#8513](https://github.com/matrix-org/synapse/issues/8513), [\#8544](https://github.com/matrix-org/synapse/issues/8544), [\#8599](https://github.com/matrix-org/synapse/issues/8599)) +- Add support for device dehydration ([MSC2697](https://github.com/matrix-org/matrix-doc/pull/2697)). 
([\#8380](https://github.com/matrix-org/synapse/issues/8380)) +- Add support for [MSC2409](https://github.com/matrix-org/matrix-doc/pull/2409), which allows sending typing, read receipts, and presence events to appservices. ([\#8437](https://github.com/matrix-org/synapse/issues/8437), [\#8590](https://github.com/matrix-org/synapse/issues/8590)) +- Change default room version to "6", per [MSC2788](https://github.com/matrix-org/matrix-doc/pull/2788). ([\#8461](https://github.com/matrix-org/synapse/issues/8461)) +- Add the ability to send non-membership events into a room via the `ModuleApi`. ([\#8479](https://github.com/matrix-org/synapse/issues/8479)) +- Increase default upload size limit from 10M to 50M. Contributed by @Akkowicz. ([\#8502](https://github.com/matrix-org/synapse/issues/8502)) +- Add support for modifying event content in `ThirdPartyRules` modules. ([\#8535](https://github.com/matrix-org/synapse/issues/8535), [\#8564](https://github.com/matrix-org/synapse/issues/8564)) + + +Bugfixes +-------- + +- Fix a longstanding bug where invalid ignored users in account data could break clients. ([\#8454](https://github.com/matrix-org/synapse/issues/8454)) +- Fix a bug where backfilling a room with an event that was missing the `redacts` field would break. ([\#8457](https://github.com/matrix-org/synapse/issues/8457)) +- Don't attempt to respond to some requests if the client has already disconnected. ([\#8465](https://github.com/matrix-org/synapse/issues/8465)) +- Fix message duplication if something goes wrong after persisting the event. ([\#8476](https://github.com/matrix-org/synapse/issues/8476)) +- Fix incremental sync returning an incorrect `prev_batch` token in timeline section, which when used to paginate returned events that were included in the incremental sync. Broken since v0.16.0. ([\#8486](https://github.com/matrix-org/synapse/issues/8486)) +- Expose the `uk.half-shot.msc2778.login.application_service` to clients from the login API. This feature was added in v1.21.0, but was not exposed as a potential login flow. ([\#8504](https://github.com/matrix-org/synapse/issues/8504)) +- Fix error code for `/profile/{userId}/displayname` to be `M_BAD_JSON`. ([\#8517](https://github.com/matrix-org/synapse/issues/8517)) +- Fix a bug introduced in v1.7.0 that could cause Synapse to insert values from non-state `m.room.retention` events into the `room_retention` database table. ([\#8527](https://github.com/matrix-org/synapse/issues/8527)) +- Fix not sending events over federation when using sharded event writers. ([\#8536](https://github.com/matrix-org/synapse/issues/8536)) +- Fix a long standing bug where email notifications for encrypted messages were blank. ([\#8545](https://github.com/matrix-org/synapse/issues/8545)) +- Fix increase in the number of `There was no active span...` errors logged when using OpenTracing. ([\#8567](https://github.com/matrix-org/synapse/issues/8567)) +- Fix a bug that prevented errors encountered during execution of the `synapse_port_db` from being correctly printed. ([\#8585](https://github.com/matrix-org/synapse/issues/8585)) +- Fix appservice transactions to only include a maximum of 100 persistent and 100 ephemeral events. ([\#8606](https://github.com/matrix-org/synapse/issues/8606)) + + +Updates to the Docker image +--------------------------- + +- Added multi-arch support (arm64,arm/v7) for the docker images. Contributed by @maquis196. 
([\#7921](https://github.com/matrix-org/synapse/issues/7921)) +- Add support for passing commandline args to the synapse process. Contributed by @samuel-p. ([\#8390](https://github.com/matrix-org/synapse/issues/8390)) + + +Improved Documentation +---------------------- + +- Update the directions for using the manhole with coroutines. ([\#8462](https://github.com/matrix-org/synapse/issues/8462)) +- Improve readme by adding new shield.io badges. ([\#8493](https://github.com/matrix-org/synapse/issues/8493)) +- Added note about docker in manhole.md regarding which ip address to bind to. Contributed by @Maquis196. ([\#8526](https://github.com/matrix-org/synapse/issues/8526)) +- Document the new behaviour of the `allowed_lifetime_min` and `allowed_lifetime_max` settings in the room retention configuration. ([\#8529](https://github.com/matrix-org/synapse/issues/8529)) + + +Deprecations and Removals +------------------------- + +- Drop unused `device_max_stream_id` table. ([\#8589](https://github.com/matrix-org/synapse/issues/8589)) + + +Internal Changes +---------------- + +- Check for unreachable code with mypy. ([\#8432](https://github.com/matrix-org/synapse/issues/8432)) +- Add unit test for event persister sharding. ([\#8433](https://github.com/matrix-org/synapse/issues/8433)) +- Allow events to be sent to clients sooner when using sharded event persisters. ([\#8439](https://github.com/matrix-org/synapse/issues/8439), [\#8488](https://github.com/matrix-org/synapse/issues/8488), [\#8496](https://github.com/matrix-org/synapse/issues/8496), [\#8499](https://github.com/matrix-org/synapse/issues/8499)) +- Configure `public_baseurl` when using demo scripts. ([\#8443](https://github.com/matrix-org/synapse/issues/8443)) +- Add SQL logging on queries that happen during startup. ([\#8448](https://github.com/matrix-org/synapse/issues/8448)) +- Speed up unit tests when using PostgreSQL. ([\#8450](https://github.com/matrix-org/synapse/issues/8450)) +- Remove redundant database loads of stream_ordering for events we already have. ([\#8452](https://github.com/matrix-org/synapse/issues/8452)) +- Reduce inconsistencies between codepaths for membership and non-membership events. ([\#8463](https://github.com/matrix-org/synapse/issues/8463)) +- Combine `SpamCheckerApi` with the more generic `ModuleApi`. ([\#8464](https://github.com/matrix-org/synapse/issues/8464)) +- Additional testing for `ThirdPartyEventRules`. ([\#8468](https://github.com/matrix-org/synapse/issues/8468)) +- Add `-d` option to `./scripts-dev/lint.sh` to lint files that have changed since the last git commit. ([\#8472](https://github.com/matrix-org/synapse/issues/8472)) +- Unblacklist some sytests. ([\#8474](https://github.com/matrix-org/synapse/issues/8474)) +- Include the log level in the phone home stats. ([\#8477](https://github.com/matrix-org/synapse/issues/8477)) +- Remove outdated sphinx documentation, scripts and configuration. ([\#8480](https://github.com/matrix-org/synapse/issues/8480)) +- Clarify error message when plugin config parsers raise an error. ([\#8492](https://github.com/matrix-org/synapse/issues/8492)) +- Remove the deprecated `Handlers` object. ([\#8494](https://github.com/matrix-org/synapse/issues/8494)) +- Fix a threadsafety bug in unit tests. ([\#8497](https://github.com/matrix-org/synapse/issues/8497)) +- Add user agent to user_daily_visits table. ([\#8503](https://github.com/matrix-org/synapse/issues/8503)) +- Add type hints to various parts of the code base. 
([\#8407](https://github.com/matrix-org/synapse/issues/8407), [\#8505](https://github.com/matrix-org/synapse/issues/8505), [\#8507](https://github.com/matrix-org/synapse/issues/8507), [\#8547](https://github.com/matrix-org/synapse/issues/8547), [\#8562](https://github.com/matrix-org/synapse/issues/8562), [\#8609](https://github.com/matrix-org/synapse/issues/8609)) +- Remove unused code from the test framework. ([\#8514](https://github.com/matrix-org/synapse/issues/8514)) +- Apply some internal fixes to the `HomeServer` class to make its code more idiomatic and statically-verifiable. ([\#8515](https://github.com/matrix-org/synapse/issues/8515)) +- Factor out common code between `RoomMemberHandler._locally_reject_invite` and `EventCreationHandler.create_event`. ([\#8537](https://github.com/matrix-org/synapse/issues/8537)) +- Improve database performance by executing more queries without starting transactions. ([\#8542](https://github.com/matrix-org/synapse/issues/8542)) +- Rename `Cache` to `DeferredCache`, to better reflect its purpose. ([\#8548](https://github.com/matrix-org/synapse/issues/8548)) +- Move metric registration code down into `LruCache`. ([\#8561](https://github.com/matrix-org/synapse/issues/8561), [\#8591](https://github.com/matrix-org/synapse/issues/8591)) +- Replace `DeferredCache` with the lighter-weight `LruCache` where possible. ([\#8563](https://github.com/matrix-org/synapse/issues/8563)) +- Add virtualenv-generated folders to `.gitignore`. ([\#8566](https://github.com/matrix-org/synapse/issues/8566)) +- Add `get_immediate` method to `DeferredCache`. ([\#8568](https://github.com/matrix-org/synapse/issues/8568)) +- Fix mypy not properly checking across the codebase, additionally, fix a typing assertion error in `handlers/auth.py`. ([\#8569](https://github.com/matrix-org/synapse/issues/8569)) +- Fix `synmark` benchmark runner. ([\#8571](https://github.com/matrix-org/synapse/issues/8571)) +- Modify `DeferredCache.get()` to return `Deferred`s instead of `ObservableDeferred`s. ([\#8572](https://github.com/matrix-org/synapse/issues/8572)) +- Adjust a protocol-type definition to fit `sqlite3` assertions. ([\#8577](https://github.com/matrix-org/synapse/issues/8577)) +- Support macOS on the `synmark` benchmark runner. ([\#8578](https://github.com/matrix-org/synapse/issues/8578)) +- Update `mypy` static type checker to 0.790. ([\#8583](https://github.com/matrix-org/synapse/issues/8583), [\#8600](https://github.com/matrix-org/synapse/issues/8600)) +- Re-organize the structured logging code to separate the TCP transport handling from the JSON formatting. ([\#8587](https://github.com/matrix-org/synapse/issues/8587)) +- Remove extraneous unittest logging decorators from unit tests. ([\#8592](https://github.com/matrix-org/synapse/issues/8592)) +- Minor optimisations in caching code. ([\#8593](https://github.com/matrix-org/synapse/issues/8593), [\#8594](https://github.com/matrix-org/synapse/issues/8594)) + + +Synapse 1.21.2 (2020-10-15) +=========================== + +Debian packages and Docker images have been rebuilt using the latest versions of dependency libraries, including authlib 0.15.1. Please see bugfixes below. + +Security advisory +----------------- + +* HTML pages served via Synapse were vulnerable to cross-site scripting (XSS) + attacks. All server administrators are encouraged to upgrade. 
+ ([\#8444](https://github.com/matrix-org/synapse/pull/8444)) + ([CVE-2020-26891](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26891)) + + This fix was originally included in v1.21.0 but was missing a security advisory. + + This was reported by [Denis Kasak](https://github.com/dkasak). + +Bugfixes +-------- + +- Fix rare bug where sending an event would fail due to a racey assertion. ([\#8530](https://github.com/matrix-org/synapse/issues/8530)) +- An updated version of the authlib dependency is included in the Docker and Debian images to fix an issue using OpenID Connect. See [\#8534](https://github.com/matrix-org/synapse/issues/8534) for details. + + Synapse 1.21.1 (2020-10-13) =========================== @@ -6074,8 +6642,8 @@ Changes in synapse 0.5.1 (2014-11-26) See UPGRADES.rst for specific instructions on how to upgrade. -> - Fix bug where we served up an Event that did not match its signatures. -> - Fix regression where we no longer correctly handled the case where a homeserver receives an event for a room it doesn\'t recognise (but is in.) +- Fix bug where we served up an Event that did not match its signatures. +- Fix regression where we no longer correctly handled the case where a homeserver receives an event for a room it doesn\'t recognise (but is in.) Changes in synapse 0.5.0 (2014-11-19) ===================================== @@ -6086,44 +6654,44 @@ This release also changes the internal database schemas and so requires servers Homeserver: -: - Add authentication and authorization to the federation protocol. Events are now signed by their originating homeservers. - - Implement the new authorization model for rooms. - - Split out web client into a seperate repository: matrix-angular-sdk. - - Change the structure of PDUs. - - Fix bug where user could not join rooms via an alias containing 4-byte UTF-8 characters. - - Merge concept of PDUs and Events internally. - - Improve logging by adding request ids to log lines. - - Implement a very basic room initial sync API. - - Implement the new invite/join federation APIs. +- Add authentication and authorization to the federation protocol. Events are now signed by their originating homeservers. +- Implement the new authorization model for rooms. +- Split out web client into a seperate repository: matrix-angular-sdk. +- Change the structure of PDUs. +- Fix bug where user could not join rooms via an alias containing 4-byte UTF-8 characters. +- Merge concept of PDUs and Events internally. +- Improve logging by adding request ids to log lines. +- Implement a very basic room initial sync API. +- Implement the new invite/join federation APIs. Webclient: -: - The webclient has been moved to a seperate repository. +- The webclient has been moved to a seperate repository. Changes in synapse 0.4.2 (2014-10-31) ===================================== Homeserver: -: - Fix bugs where we did not notify users of correct presence updates. - - Fix bug where we did not handle sub second event stream timeouts. +- Fix bugs where we did not notify users of correct presence updates. +- Fix bug where we did not handle sub second event stream timeouts. Webclient: -: - Add ability to click on messages to see JSON. - - Add ability to redact messages. - - Add ability to view and edit all room state JSON. - - Handle incoming redactions. - - Improve feedback on errors. - - Fix bugs in mobile CSS. - - Fix bugs with desktop notifications. +- Add ability to click on messages to see JSON. +- Add ability to redact messages. 
+- Add ability to view and edit all room state JSON. +- Handle incoming redactions. +- Improve feedback on errors. +- Fix bugs in mobile CSS. +- Fix bugs with desktop notifications. Changes in synapse 0.4.1 (2014-10-17) ===================================== Webclient: -: - Fix bug with display of timestamps. +- Fix bug with display of timestamps. Changes in synpase 0.4.0 (2014-10-17) ===================================== @@ -6136,8 +6704,8 @@ You will also need an updated syutil and config. See UPGRADES.rst. Homeserver: -: - Sign federation transactions to assert strong identity over federation. - - Rename timestamp keys in PDUs and events from \'ts\' and \'hsob\_ts\' to \'origin\_server\_ts\'. +- Sign federation transactions to assert strong identity over federation. +- Rename timestamp keys in PDUs and events from \'ts\' and \'hsob\_ts\' to \'origin\_server\_ts\'. Changes in synapse 0.3.4 (2014-09-25) ===================================== @@ -6146,48 +6714,48 @@ This version adds support for using a TURN server. See docs/turn-howto.rst on ho Homeserver: -: - Add support for redaction of messages. - - Fix bug where inviting a user on a remote home server could take up to 20-30s. - - Implement a get current room state API. - - Add support specifying and retrieving turn server configuration. +- Add support for redaction of messages. +- Fix bug where inviting a user on a remote home server could take up to 20-30s. +- Implement a get current room state API. +- Add support specifying and retrieving turn server configuration. Webclient: -: - Add button to send messages to users from the home page. - - Add support for using TURN for VoIP calls. - - Show display name change messages. - - Fix bug where the client didn\'t get the state of a newly joined room until after it has been refreshed. - - Fix bugs with tab complete. - - Fix bug where holding down the down arrow caused chrome to chew 100% CPU. - - Fix bug where desktop notifications occasionally used \"Undefined\" as the display name. - - Fix more places where we sometimes saw room IDs incorrectly. - - Fix bug which caused lag when entering text in the text box. +- Add button to send messages to users from the home page. +- Add support for using TURN for VoIP calls. +- Show display name change messages. +- Fix bug where the client didn\'t get the state of a newly joined room until after it has been refreshed. +- Fix bugs with tab complete. +- Fix bug where holding down the down arrow caused chrome to chew 100% CPU. +- Fix bug where desktop notifications occasionally used \"Undefined\" as the display name. +- Fix more places where we sometimes saw room IDs incorrectly. +- Fix bug which caused lag when entering text in the text box. Changes in synapse 0.3.3 (2014-09-22) ===================================== Homeserver: -: - Fix bug where you continued to get events for rooms you had left. +- Fix bug where you continued to get events for rooms you had left. Webclient: -: - Add support for video calls with basic UI. - - Fix bug where one to one chats were named after your display name rather than the other person\'s. - - Fix bug which caused lag when typing in the textarea. - - Refuse to run on browsers we know won\'t work. - - Trigger pagination when joining new rooms. - - Fix bug where we sometimes didn\'t display invitations in recents. - - Automatically join room when accepting a VoIP call. - - Disable outgoing and reject incoming calls on browsers we don\'t support VoIP in. 
- - Don\'t display desktop notifications for messages in the room you are non-idle and speaking in. +- Add support for video calls with basic UI. +- Fix bug where one to one chats were named after your display name rather than the other person\'s. +- Fix bug which caused lag when typing in the textarea. +- Refuse to run on browsers we know won\'t work. +- Trigger pagination when joining new rooms. +- Fix bug where we sometimes didn\'t display invitations in recents. +- Automatically join room when accepting a VoIP call. +- Disable outgoing and reject incoming calls on browsers we don\'t support VoIP in. +- Don\'t display desktop notifications for messages in the room you are non-idle and speaking in. Changes in synapse 0.3.2 (2014-09-18) ===================================== Webclient: -: - Fix bug where an empty \"bing words\" list in old accounts didn\'t send notifications when it should have done. +- Fix bug where an empty \"bing words\" list in old accounts didn\'t send notifications when it should have done. Changes in synapse 0.3.1 (2014-09-18) ===================================== @@ -6196,8 +6764,8 @@ This is a release to hotfix v0.3.0 to fix two regressions. Webclient: -: - Fix a regression where we sometimes displayed duplicate events. - - Fix a regression where we didn\'t immediately remove rooms you were banned in from the recents list. +- Fix a regression where we sometimes displayed duplicate events. +- Fix a regression where we didn\'t immediately remove rooms you were banned in from the recents list. Changes in synapse 0.3.0 (2014-09-18) ===================================== @@ -6206,91 +6774,91 @@ See UPGRADE for information about changes to the client server API, including br Homeserver: -: - When a user changes their displayname or avatar the server will now update all their join states to reflect this. - - The server now adds \"age\" key to events to indicate how old they are. This is clock independent, so at no point does any server or webclient have to assume their clock is in sync with everyone else. - - Fix bug where we didn\'t correctly pull in missing PDUs. - - Fix bug where prev\_content key wasn\'t always returned. - - Add support for password resets. +- When a user changes their displayname or avatar the server will now update all their join states to reflect this. +- The server now adds \"age\" key to events to indicate how old they are. This is clock independent, so at no point does any server or webclient have to assume their clock is in sync with everyone else. +- Fix bug where we didn\'t correctly pull in missing PDUs. +- Fix bug where prev\_content key wasn\'t always returned. +- Add support for password resets. Webclient: -: - Improve page content loading. - - Join/parts now trigger desktop notifications. - - Always show room aliases in the UI if one is present. - - No longer show user-count in the recents side panel. - - Add up & down arrow support to the text box for message sending to step through your sent history. - - Don\'t display notifications for our own messages. - - Emotes are now formatted correctly in desktop notifications. - - The recents list now differentiates between public & private rooms. - - Fix bug where when switching between rooms the pagination flickered before the view jumped to the bottom of the screen. - - Add bing word support. +- Improve page content loading. +- Join/parts now trigger desktop notifications. +- Always show room aliases in the UI if one is present. +- No longer show user-count in the recents side panel. 
+- Add up & down arrow support to the text box for message sending to step through your sent history. +- Don\'t display notifications for our own messages. +- Emotes are now formatted correctly in desktop notifications. +- The recents list now differentiates between public & private rooms. +- Fix bug where when switching between rooms the pagination flickered before the view jumped to the bottom of the screen. +- Add bing word support. Registration API: -: - The registration API has been overhauled to function like the login API. In practice, this means registration requests must now include the following: \'type\':\'m.login.password\'. See UPGRADE for more information on this. - - The \'user\_id\' key has been renamed to \'user\' to better match the login API. - - There is an additional login type: \'m.login.email.identity\'. - - The command client and web client have been updated to reflect these changes. +- The registration API has been overhauled to function like the login API. In practice, this means registration requests must now include the following: \'type\':\'m.login.password\'. See UPGRADE for more information on this. +- The \'user\_id\' key has been renamed to \'user\' to better match the login API. +- There is an additional login type: \'m.login.email.identity\'. +- The command client and web client have been updated to reflect these changes. Changes in synapse 0.2.3 (2014-09-12) ===================================== Homeserver: -: - Fix bug where we stopped sending events to remote home servers if a user from that home server left, even if there were some still in the room. - - Fix bugs in the state conflict resolution where it was incorrectly rejecting events. +- Fix bug where we stopped sending events to remote home servers if a user from that home server left, even if there were some still in the room. +- Fix bugs in the state conflict resolution where it was incorrectly rejecting events. Webclient: -: - Display room names and topics. - - Allow setting/editing of room names and topics. - - Display information about rooms on the main page. - - Handle ban and kick events in real time. - - VoIP UI and reliability improvements. - - Add glare support for VoIP. - - Improvements to initial startup speed. - - Don\'t display duplicate join events. - - Local echo of messages. - - Differentiate sending and sent of local echo. - - Various minor bug fixes. +- Display room names and topics. +- Allow setting/editing of room names and topics. +- Display information about rooms on the main page. +- Handle ban and kick events in real time. +- VoIP UI and reliability improvements. +- Add glare support for VoIP. +- Improvements to initial startup speed. +- Don\'t display duplicate join events. +- Local echo of messages. +- Differentiate sending and sent of local echo. +- Various minor bug fixes. Changes in synapse 0.2.2 (2014-09-06) ===================================== Homeserver: -: - When the server returns state events it now also includes the previous content. - - Add support for inviting people when creating a new room. - - Make the homeserver inform the room via m.room.aliases when a new alias is added for a room. - - Validate m.room.power\_level events. +- When the server returns state events it now also includes the previous content. +- Add support for inviting people when creating a new room. +- Make the homeserver inform the room via m.room.aliases when a new alias is added for a room. +- Validate m.room.power\_level events. 
Webclient: -: - Add support for captchas on registration. - - Handle m.room.aliases events. - - Asynchronously send messages and show a local echo. - - Inform the UI when a message failed to send. - - Only autoscroll on receiving a new message if the user was already at the bottom of the screen. - - Add support for ban/kick reasons. +- Add support for captchas on registration. +- Handle m.room.aliases events. +- Asynchronously send messages and show a local echo. +- Inform the UI when a message failed to send. +- Only autoscroll on receiving a new message if the user was already at the bottom of the screen. +- Add support for ban/kick reasons. Changes in synapse 0.2.1 (2014-09-03) ===================================== Homeserver: -: - Added support for signing up with a third party id. - - Add synctl scripts. - - Added rate limiting. - - Add option to change the external address the content repo uses. - - Presence bug fixes. +- Added support for signing up with a third party id. +- Add synctl scripts. +- Added rate limiting. +- Add option to change the external address the content repo uses. +- Presence bug fixes. Webclient: -: - Added support for signing up with a third party id. - - Added support for banning and kicking users. - - Added support for displaying and setting ops. - - Added support for room names. - - Fix bugs with room membership event display. +- Added support for signing up with a third party id. +- Added support for banning and kicking users. +- Added support for displaying and setting ops. +- Added support for room names. +- Fix bugs with room membership event display. Changes in synapse 0.2.0 (2014-09-02) ===================================== @@ -6299,36 +6867,36 @@ This update changes many configuration options, updates the database schema and Homeserver: -: - Require SSL for server-server connections. - - Add SSL listener for client-server connections. - - Add ability to use config files. - - Add support for kicking/banning and power levels. - - Allow setting of room names and topics on creation. - - Change presence to include last seen time of the user. - - Change url path prefix to /\_matrix/\... - - Bug fixes to presence. +- Require SSL for server-server connections. +- Add SSL listener for client-server connections. +- Add ability to use config files. +- Add support for kicking/banning and power levels. +- Allow setting of room names and topics on creation. +- Change presence to include last seen time of the user. +- Change url path prefix to /\_matrix/\... +- Bug fixes to presence. Webclient: -: - Reskin the CSS for registration and login. - - Various improvements to rooms CSS. - - Support changes in client-server API. - - Bug fixes to VOIP UI. - - Various bug fixes to handling of changes to room member list. +- Reskin the CSS for registration and login. +- Various improvements to rooms CSS. +- Support changes in client-server API. +- Bug fixes to VOIP UI. +- Various bug fixes to handling of changes to room member list. Changes in synapse 0.1.2 (2014-08-29) ===================================== Webclient: -: - Add basic call state UI for VoIP calls. +- Add basic call state UI for VoIP calls. Changes in synapse 0.1.1 (2014-08-29) ===================================== Homeserver: -: - Fix bug that caused the event stream to not notify some clients about changes. +- Fix bug that caused the event stream to not notify some clients about changes. 
Changes in synapse 0.1.0 (2014-08-29) ===================================== @@ -6337,26 +6905,22 @@ Presence has been reenabled in this release. Homeserver: -: - - - Update client to server API, including: - - : - Use a more consistent url scheme. - - Provide more useful information in the initial sync api. - - - Change the presence handling to be much more efficient. - - Change the presence server to server API to not require explicit polling of all users who share a room with a user. - - Fix races in the event streaming logic. +- Update client to server API, including: + - Use a more consistent url scheme. + - Provide more useful information in the initial sync api. +- Change the presence handling to be much more efficient. +- Change the presence server to server API to not require explicit polling of all users who share a room with a user. +- Fix races in the event streaming logic. Webclient: -: - Update to use new client to server API. - - Add basic VOIP support. - - Add idle timers that change your status to away. - - Add recent rooms column when viewing a room. - - Various network efficiency improvements. - - Add basic mobile browser support. - - Add a settings page. +- Update to use new client to server API. +- Add basic VOIP support. +- Add idle timers that change your status to away. +- Add recent rooms column when viewing a room. +- Various network efficiency improvements. +- Add basic mobile browser support. +- Add a settings page. Changes in synapse 0.0.1 (2014-08-22) ===================================== @@ -6365,26 +6929,26 @@ Presence has been disabled in this release due to a bug that caused the homeserv Homeserver: -: - Completely change the database schema to support generic event types. - - Improve presence reliability. - - Improve reliability of joining remote rooms. - - Fix bug where room join events were duplicated. - - Improve initial sync API to return more information to the client. - - Stop generating fake messages for room membership events. +- Completely change the database schema to support generic event types. +- Improve presence reliability. +- Improve reliability of joining remote rooms. +- Fix bug where room join events were duplicated. +- Improve initial sync API to return more information to the client. +- Stop generating fake messages for room membership events. Webclient: -: - Add tab completion of names. - - Add ability to upload and send images. - - Add profile pages. - - Improve CSS layout of room. - - Disambiguate identical display names. - - Don\'t get remote users display names and avatars individually. - - Use the new initial sync API to reduce number of round trips to the homeserver. - - Change url scheme to use room aliases instead of room ids where known. - - Increase longpoll timeout. +- Add tab completion of names. +- Add ability to upload and send images. +- Add profile pages. +- Improve CSS layout of room. +- Disambiguate identical display names. +- Don\'t get remote users display names and avatars individually. +- Use the new initial sync API to reduce number of round trips to the homeserver. +- Change url scheme to use room aliases instead of room ids where known. +- Increase longpoll timeout. Changes in synapse 0.0.0 (2014-08-13) ===================================== -> - Initial alpha release +- Initial alpha release diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 524f82433d..1d7bb8f969 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md
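The CONTRIBUTING.md hunk below adds `mypy` to the lint dependencies and documents a new `-d` option for the lint script. As a rough sketch of the resulting workflow (assuming a Synapse checkout with its virtualenv already activated):

```sh
# Install the linting and type-checking dependencies
pip install -e ".[lint,mypy]"

# Check only the files changed since the last git commit (new -d option)
./scripts-dev/lint.sh -d

# Or lint specific files and folders explicitly
./scripts-dev/lint.sh path/to/file1.py path/to/folder
```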
@@ -46,7 +46,7 @@ locally. You'll need python 3.6 or later, and to install a number of tools: ``` # Install the dependencies -pip install -e ".[lint]" +pip install -e ".[lint,mypy]" # Run the linter script ./scripts-dev/lint.sh @@ -63,6 +63,10 @@ run-time: ./scripts-dev/lint.sh path/to/file1.py path/to/file2.py path/to/folder ``` +You can also provide the `-d` option, which will lint the files that have been +changed since the last git commit. This will often be significantly faster than +linting the whole codebase. + Before pushing new changes, ensure they don't produce linting errors. Commit any files that were corrected. @@ -152,6 +156,24 @@ directory, you will need both a regular newsfragment *and* an entry in the debian changelog. (Though typically such changes should be submitted as two separate pull requests.) +## Documentation + +There is a growing amount of documentation located in the [docs](docs) +directory. This documentation is intended primarily for sysadmins running their +own Synapse instance, as well as developers interacting externally with +Synapse. [docs/dev](docs/dev) exists primarily to house documentation for +Synapse developers. [docs/admin_api](docs/admin_api) houses documentation +regarding Synapse's Admin API, which is used mostly by sysadmins and external +service developers. + +New files added to both folders should be written in [Github-Flavoured +Markdown](https://guides.github.com/features/mastering-markdown/), and attempts +should be made to migrate existing documents to markdown where possible. + +Some documentation also exists in [Synapse's Github +Wiki](https://github.com/matrix-org/synapse/wiki), although this is primarily +contributed to by community authors. + ## Sign off In order to have a concrete record that your contribution is intentional diff --git a/INSTALL.md b/INSTALL.md
index 22f7b7c029..598ddceb8c 100644 --- a/INSTALL.md +++ b/INSTALL.md
@@ -1,19 +1,44 @@ -- [Choosing your server name](#choosing-your-server-name) -- [Picking a database engine](#picking-a-database-engine) -- [Installing Synapse](#installing-synapse) - - [Installing from source](#installing-from-source) - - [Platform-Specific Instructions](#platform-specific-instructions) - - [Prebuilt packages](#prebuilt-packages) -- [Setting up Synapse](#setting-up-synapse) - - [TLS certificates](#tls-certificates) - - [Client Well-Known URI](#client-well-known-uri) - - [Email](#email) - - [Registering a user](#registering-a-user) - - [Setting up a TURN server](#setting-up-a-turn-server) - - [URL previews](#url-previews) -- [Troubleshooting Installation](#troubleshooting-installation) - -# Choosing your server name +# Installation Instructions + +There are 3 steps to follow under **Installation Instructions**. + +- [Installation Instructions](#installation-instructions) + - [Choosing your server name](#choosing-your-server-name) + - [Installing Synapse](#installing-synapse) + - [Installing from source](#installing-from-source) + - [Platform-Specific Instructions](#platform-specific-instructions) + - [Debian/Ubuntu/Raspbian](#debianubunturaspbian) + - [ArchLinux](#archlinux) + - [CentOS/Fedora](#centosfedora) + - [macOS](#macos) + - [OpenSUSE](#opensuse) + - [OpenBSD](#openbsd) + - [Windows](#windows) + - [Prebuilt packages](#prebuilt-packages) + - [Docker images and Ansible playbooks](#docker-images-and-ansible-playbooks) + - [Debian/Ubuntu](#debianubuntu) + - [Matrix.org packages](#matrixorg-packages) + - [Downstream Debian packages](#downstream-debian-packages) + - [Downstream Ubuntu packages](#downstream-ubuntu-packages) + - [Fedora](#fedora) + - [OpenSUSE](#opensuse-1) + - [SUSE Linux Enterprise Server](#suse-linux-enterprise-server) + - [ArchLinux](#archlinux-1) + - [Void Linux](#void-linux) + - [FreeBSD](#freebsd) + - [OpenBSD](#openbsd-1) + - [NixOS](#nixos) + - [Setting up Synapse](#setting-up-synapse) + - [Using PostgreSQL](#using-postgresql) + - [TLS certificates](#tls-certificates) + - [Client Well-Known URI](#client-well-known-uri) + - [Email](#email) + - [Registering a user](#registering-a-user) + - [Setting up a TURN server](#setting-up-a-turn-server) + - [URL previews](#url-previews) + - [Troubleshooting Installation](#troubleshooting-installation) + +## Choosing your server name It is important to choose the name for your server before you install Synapse, because it cannot be changed later. @@ -29,35 +54,16 @@ that your email address is probably `user@example.com` rather than `user@email.example.com`) - but doing so may require more advanced setup: see [Setting up Federation](docs/federate.md). -# Picking a database engine +## Installing Synapse -Synapse offers two database engines: - * [PostgreSQL](https://www.postgresql.org) - * [SQLite](https://sqlite.org/) - -Almost all installations should opt to use PostgreSQL. Advantages include: - -* significant performance improvements due to the superior threading and - caching model, smarter query optimiser -* allowing the DB to be run on separate hardware - -For information on how to install and use PostgreSQL, please see -[docs/postgres.md](docs/postgres.md) - -By default Synapse uses SQLite and in doing so trades performance for convenience. -SQLite is only recommended in Synapse for testing purposes or for servers with -light workloads. 
- -# Installing Synapse - -## Installing from source +### Installing from source (Prebuilt packages are available for some platforms - see [Prebuilt packages](#prebuilt-packages).) System requirements: - POSIX-compliant system (tested on Linux & OS X) -- Python 3.5.2 or later, up to Python 3.8. +- Python 3.5.2 or later, up to Python 3.9. - At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org Synapse is written in Python but some of the libraries it uses are written in @@ -68,7 +74,7 @@ these on various platforms. To install the Synapse homeserver run: -``` +```sh mkdir -p ~/synapse virtualenv -p python3 ~/synapse/env source ~/synapse/env/bin/activate @@ -85,7 +91,7 @@ prefer. This Synapse installation can then be later upgraded by using pip again with the update flag: -``` +```sh source ~/synapse/env/bin/activate pip install -U matrix-synapse ``` @@ -93,7 +99,7 @@ pip install -U matrix-synapse Before you can start Synapse, you will need to generate a configuration file. To do this, run (in your virtualenv, as before): -``` +```sh cd ~/synapse python -m synapse.app.homeserver \ --server-name my.domain.name \ @@ -111,45 +117,43 @@ wise to back them up somewhere safe. (If, for whatever reason, you do need to change your homeserver's keys, you may find that other homeserver have the old key cached. If you update the signing key, you should change the name of the key in the `<server name>.signing.key` file (the second word) to something -different. See the -[spec](https://matrix.org/docs/spec/server_server/latest.html#retrieving-server-keys) -for more information on key management). +different. See the [spec](https://matrix.org/docs/spec/server_server/latest.html#retrieving-server-keys) for more information on key management). To actually run your new homeserver, pick a working directory for Synapse to run (e.g. `~/synapse`), and: -``` +```sh cd ~/synapse source env/bin/activate synctl start ``` -### Platform-Specific Instructions +#### Platform-Specific Instructions -#### Debian/Ubuntu/Raspbian +##### Debian/Ubuntu/Raspbian Installing prerequisites on Ubuntu or Debian: -``` -sudo apt-get install build-essential python3-dev libffi-dev \ +```sh +sudo apt install build-essential python3-dev libffi-dev \ python3-pip python3-setuptools sqlite3 \ libssl-dev virtualenv libjpeg-dev libxslt1-dev ``` -#### ArchLinux +##### ArchLinux Installing prerequisites on ArchLinux: -``` +```sh sudo pacman -S base-devel python python-pip \ python-setuptools python-virtualenv sqlite3 ``` -#### CentOS/Fedora +##### CentOS/Fedora Installing prerequisites on CentOS 8 or Fedora>26: -``` +```sh sudo dnf install libtiff-devel libjpeg-devel libzip-devel freetype-devel \ libwebp-devel tk-devel redhat-rpm-config \ python3-virtualenv libffi-devel openssl-devel @@ -158,7 +162,7 @@ sudo dnf groupinstall "Development Tools" Installing prerequisites on CentOS 7 or Fedora<=25: -``` +```sh sudo yum install libtiff-devel libjpeg-devel libzip-devel freetype-devel \ lcms2-devel libwebp-devel tcl-devel tk-devel redhat-rpm-config \ python3-virtualenv libffi-devel openssl-devel @@ -170,11 +174,11 @@ uses SQLite 3.7. You may be able to work around this by installing a more recent SQLite version, but it is recommended that you instead use a Postgres database: see [docs/postgres.md](docs/postgres.md). 
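Where the notes above recommend moving from an outdated SQLite to PostgreSQL, the change ultimately lands in the `database` section of `homeserver.yaml`. A minimal sketch, assuming a local PostgreSQL server and placeholder credentials (`synapse_user`, `secretpassword`, database `synapse`); see [docs/postgres.md](docs/postgres.md) for the authoritative setup steps:

```yaml
# Sketch only: the values below are placeholders, not defaults.
database:
  name: psycopg2        # use the PostgreSQL driver instead of sqlite3
  args:
    user: synapse_user
    password: secretpassword
    database: synapse
    host: localhost
    cp_min: 5           # connection pool sizing
    cp_max: 10
```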
-#### macOS +##### macOS Installing prerequisites on macOS: -``` +```sh xcode-select --install sudo easy_install pip sudo pip install virtualenv @@ -184,22 +188,22 @@ brew install pkg-config libffi On macOS Catalina (10.15) you may need to explicitly install OpenSSL via brew and inform `pip` about it so that `psycopg2` builds: -``` +```sh brew install openssl@1.1 export LDFLAGS=-L/usr/local/Cellar/openssl\@1.1/1.1.1d/lib/ ``` -#### OpenSUSE +##### OpenSUSE Installing prerequisites on openSUSE: -``` +```sh sudo zypper in -t pattern devel_basis sudo zypper in python-pip python-setuptools sqlite3 python-virtualenv \ python-devel libffi-devel libopenssl-devel libjpeg62-devel ``` -#### OpenBSD +##### OpenBSD A port of Synapse is available under `net/synapse`. The filesystem underlying the homeserver directory (defaults to `/var/synapse`) has to be @@ -213,73 +217,72 @@ mounted with `wxallowed` (cf. `mount(8)`). Creating a `WRKOBJDIR` for building python under `/usr/local` (which on a default OpenBSD installation is mounted with `wxallowed`): -``` +```sh doas mkdir /usr/local/pobj_wxallowed ``` Assuming `PORTS_PRIVSEP=Yes` (cf. `bsd.port.mk(5)`) and `SUDO=doas` are configured in `/etc/mk.conf`: -``` +```sh doas chown _pbuild:_pbuild /usr/local/pobj_wxallowed ``` Setting the `WRKOBJDIR` for building python: -``` +```sh echo WRKOBJDIR_lang/python/3.7=/usr/local/pobj_wxallowed \\nWRKOBJDIR_lang/python/2.7=/usr/local/pobj_wxallowed >> /etc/mk.conf ``` Building Synapse: -``` +```sh cd /usr/ports/net/synapse make install ``` -#### Windows +##### Windows If you wish to run or develop Synapse on Windows, the Windows Subsystem For Linux provides a Linux environment on Windows 10 which is capable of using the Debian, Fedora, or source installation methods. More information about WSL can -be found at https://docs.microsoft.com/en-us/windows/wsl/install-win10 for -Windows 10 and https://docs.microsoft.com/en-us/windows/wsl/install-on-server +be found at <https://docs.microsoft.com/en-us/windows/wsl/install-win10> for +Windows 10 and <https://docs.microsoft.com/en-us/windows/wsl/install-on-server> for Windows Server. -## Prebuilt packages +### Prebuilt packages As an alternative to installing from source, prebuilt packages are available for a number of platforms. -### Docker images and Ansible playbooks +#### Docker images and Ansible playbooks There is an offical synapse image available at -https://hub.docker.com/r/matrixdotorg/synapse which can be used with +<https://hub.docker.com/r/matrixdotorg/synapse> which can be used with the docker-compose file available at [contrib/docker](contrib/docker). Further information on this including configuration options is available in the README on hub.docker.com. Alternatively, Andreas Peters (previously Silvio Fricke) has contributed a Dockerfile to automate a synapse server in a single Docker image, at -https://hub.docker.com/r/avhost/docker-matrix/tags/ +<https://hub.docker.com/r/avhost/docker-matrix/tags/> Slavi Pantaleev has created an Ansible playbook, which installs the offical Docker image of Matrix Synapse along with many other Matrix-related services (Postgres database, Element, coturn, ma1sd, SSL support, etc.). 
For more details, see -https://github.com/spantaleev/matrix-docker-ansible-deploy - +<https://github.com/spantaleev/matrix-docker-ansible-deploy> -### Debian/Ubuntu +#### Debian/Ubuntu -#### Matrix.org packages +##### Matrix.org packages Matrix.org provides Debian/Ubuntu packages of the latest stable version of -Synapse via https://packages.matrix.org/debian/. They are available for Debian +Synapse via <https://packages.matrix.org/debian/>. They are available for Debian 9 (Stretch), Ubuntu 16.04 (Xenial), and later. To use them: -``` +```sh sudo apt install -y lsb-release wget apt-transport-https sudo wget -O /usr/share/keyrings/matrix-org-archive-keyring.gpg https://packages.matrix.org/debian/matrix-org-archive-keyring.gpg echo "deb [signed-by=/usr/share/keyrings/matrix-org-archive-keyring.gpg] https://packages.matrix.org/debian/ $(lsb_release -cs) main" | @@ -299,7 +302,7 @@ The fingerprint of the repository signing key (as shown by `gpg /usr/share/keyrings/matrix-org-archive-keyring.gpg`) is `AAF9AE843A7584B5A3E4CD2BCF45A512DE2DA058`. -#### Downstream Debian packages +##### Downstream Debian packages We do not recommend using the packages from the default Debian `buster` repository at this time, as they are old and suffer from known security @@ -311,49 +314,49 @@ for information on how to use backports. If you are using Debian `sid` or testing, Synapse is available in the default repositories and it should be possible to install it simply with: -``` +```sh sudo apt install matrix-synapse ``` -#### Downstream Ubuntu packages +##### Downstream Ubuntu packages We do not recommend using the packages in the default Ubuntu repository at this time, as they are old and suffer from known security vulnerabilities. The latest version of Synapse can be installed from [our repository](#matrixorg-packages). -### Fedora +#### Fedora Synapse is in the Fedora repositories as `matrix-synapse`: -``` +```sh sudo dnf install matrix-synapse ``` Oleg Girko provides Fedora RPMs at -https://obs.infoserver.lv/project/monitor/matrix-synapse +<https://obs.infoserver.lv/project/monitor/matrix-synapse> -### OpenSUSE +#### OpenSUSE Synapse is in the OpenSUSE repositories as `matrix-synapse`: -``` +```sh sudo zypper install matrix-synapse ``` -### SUSE Linux Enterprise Server +#### SUSE Linux Enterprise Server Unofficial package are built for SLES 15 in the openSUSE:Backports:SLE-15 repository at -https://download.opensuse.org/repositories/openSUSE:/Backports:/SLE-15/standard/ +<https://download.opensuse.org/repositories/openSUSE:/Backports:/SLE-15/standard/> -### ArchLinux +#### ArchLinux The quickest way to get up and running with ArchLinux is probably with the community package -https://www.archlinux.org/packages/community/any/matrix-synapse/, which should pull in most of +<https://www.archlinux.org/packages/community/any/matrix-synapse/>, which should pull in most of the necessary dependencies. pip may be outdated (6.0.7-1 and needs to be upgraded to 6.0.8-1 ): -``` +```sh sudo pip install --upgrade pip ``` @@ -362,28 +365,28 @@ ELFCLASS32 (x64 Systems), you may need to reinstall py-bcrypt to correctly compile it under the right architecture. 
(This should not be needed if installing under virtualenv): -``` +```sh sudo pip uninstall py-bcrypt sudo pip install py-bcrypt ``` -### Void Linux +#### Void Linux Synapse can be found in the void repositories as 'synapse': -``` +```sh xbps-install -Su xbps-install -S synapse ``` -### FreeBSD +#### FreeBSD Synapse can be installed via FreeBSD Ports or Packages contributed by Brendan Molloy from: - - Ports: `cd /usr/ports/net-im/py-matrix-synapse && make install clean` - - Packages: `pkg install py37-matrix-synapse` +- Ports: `cd /usr/ports/net-im/py-matrix-synapse && make install clean` +- Packages: `pkg install py37-matrix-synapse` -### OpenBSD +#### OpenBSD As of OpenBSD 6.7 Synapse is available as a pre-compiled binary. The filesystem underlying the homeserver directory (defaults to `/var/synapse`) has to be @@ -392,20 +395,35 @@ and mounting it to `/var/synapse` should be taken into consideration. Installing Synapse: -``` +```sh doas pkg_add synapse ``` -### NixOS +#### NixOS Robin Lambertz has packaged Synapse for NixOS at: -https://github.com/NixOS/nixpkgs/blob/master/nixos/modules/services/misc/matrix-synapse.nix +<https://github.com/NixOS/nixpkgs/blob/master/nixos/modules/services/misc/matrix-synapse.nix> -# Setting up Synapse +## Setting up Synapse Once you have installed synapse as above, you will need to configure it. -## TLS certificates +### Using PostgreSQL + +By default Synapse uses [SQLite](https://sqlite.org/) and in doing so trades performance for convenience. +SQLite is only recommended in Synapse for testing purposes or for servers with +very light workloads. + +Almost all installations should opt to use [PostgreSQL](https://www.postgresql.org). Advantages include: + +- significant performance improvements due to the superior threading and + caching model, smarter query optimiser +- allowing the DB to be run on separate hardware + +For information on how to install and use PostgreSQL in Synapse, please see +[docs/postgres.md](docs/postgres.md) + +### TLS certificates The default configuration exposes a single HTTP port on the local interface: `http://localhost:8008`. It is suitable for local testing, @@ -419,19 +437,19 @@ The recommended way to do so is to set up a reverse proxy on port Alternatively, you can configure Synapse to expose an HTTPS port. To do so, you will need to edit `homeserver.yaml`, as follows: -* First, under the `listeners` section, uncomment the configuration for the +- First, under the `listeners` section, uncomment the configuration for the TLS-enabled listener. (Remove the hash sign (`#`) at the start of each line). The relevant lines are like this: - ``` - - port: 8448 - type: http - tls: true - resources: - - names: [client, federation] +```yaml + - port: 8448 + type: http + tls: true + resources: + - names: [client, federation] ``` -* You will also need to uncomment the `tls_certificate_path` and +- You will also need to uncomment the `tls_certificate_path` and `tls_private_key_path` lines under the `TLS` section. You will need to manage provisioning of these certificates yourself — Synapse had built-in ACME support, but the ACMEv1 protocol Synapse implements is deprecated, not @@ -446,7 +464,7 @@ so, you will need to edit `homeserver.yaml`, as follows: For a more detailed guide to configuring your server for federation, see [federate.md](docs/federate.md). -## Client Well-Known URI +### Client Well-Known URI Setting up the client Well-Known URI is optional but if you set it up, it will allow users to enter their full username (e.g. 
`@user:<server_name>`) into clients @@ -457,7 +475,7 @@ about the actual homeserver URL you are using. The URL `https://<server_name>/.well-known/matrix/client` should return JSON in the following format. -``` +```json { "m.homeserver": { "base_url": "https://<matrix.example.com>" @@ -467,7 +485,7 @@ the following format. It can optionally contain identity server information as well. -``` +```json { "m.homeserver": { "base_url": "https://<matrix.example.com>" @@ -484,10 +502,11 @@ Cross-Origin Resource Sharing (CORS) headers. A recommended value would be view it. In nginx this would be something like: -``` + +```nginx location /.well-known/matrix/client { return 200 '{"m.homeserver": {"base_url": "https://<matrix.example.com>"}}'; - add_header Content-Type application/json; + default_type application/json; add_header Access-Control-Allow-Origin *; } ``` @@ -497,11 +516,11 @@ correctly. `public_baseurl` should be set to the URL that clients will use to connect to your server. This is the same URL you put for the `m.homeserver` `base_url` above. -``` +```yaml public_baseurl: "https://<matrix.example.com>" ``` -## Email +### Email It is desirable for Synapse to have the capability to send email. This allows Synapse to send password reset emails, send verifications when an email address @@ -516,7 +535,7 @@ and `notif_from` fields filled out. You may also need to set `smtp_user`, If email is not configured, password reset, registration and notifications via email will be disabled. -## Registering a user +### Registering a user The easiest way to create a new user is to do so from a client like [Element](https://element.io/). @@ -524,7 +543,7 @@ Alternatively you can do so from the command line if you have installed via pip. This can be done as follows: -``` +```sh $ source ~/synapse/env/bin/activate $ synctl start # if not already running $ register_new_matrix_user -c homeserver.yaml http://localhost:8008 @@ -542,12 +561,12 @@ value is generated by `--generate-config`), but it should be kept secret, as anyone with knowledge of it can register users, including admin accounts, on your server even if `enable_registration` is `false`. -## Setting up a TURN server +### Setting up a TURN server For reliable VoIP calls to be routed via this homeserver, you MUST configure a TURN server. See [docs/turn-howto.md](docs/turn-howto.md) for details. -## URL previews +### URL previews Synapse includes support for previewing URLs, which is disabled by default. To turn it on you must enable the `url_preview_enabled: True` config parameter @@ -557,19 +576,18 @@ This is critical from a security perspective to stop arbitrary Matrix users spidering 'internal' URLs on your network. At the very least we recommend that your loopback and RFC1918 IP addresses are blacklisted. -This also requires the optional `lxml` and `netaddr` python dependencies to be -installed. This in turn requires the `libxml2` library to be available - on -Debian/Ubuntu this means `apt-get install libxml2-dev`, or equivalent for -your OS. +This also requires the optional `lxml` python dependency to be installed. This +in turn requires the `libxml2` library to be available - on Debian/Ubuntu this +means `apt-get install libxml2-dev`, or equivalent for your OS. -# Troubleshooting Installation +### Troubleshooting Installation `pip` seems to leak *lots* of memory during installation. For instance, a Linux host with 512MB of RAM may run out of memory whilst installing Twisted. 
If this happens, you will have to individually install the dependencies which are failing, e.g.: -``` +```sh pip install twisted ``` diff --git a/README.rst b/README.rst
index 4a189c8bc4..31ae5cc578 100644 --- a/README.rst +++ b/README.rst
@@ -1,10 +1,6 @@ -================ -Synapse |shield| -================ - -.. |shield| image:: https://img.shields.io/matrix/synapse:matrix.org?label=support&logo=matrix - :alt: (get support on #synapse:matrix.org) - :target: https://matrix.to/#/#synapse:matrix.org +========================================================= +Synapse |support| |development| |license| |pypi| |python| +========================================================= .. contents:: @@ -247,6 +243,8 @@ Then update the ``users`` table in the database:: Synapse Development =================== +Join our developer community on Matrix: [#synapse-dev:matrix.org](https://matrix.to/#/#synapse-dev:matrix.org) + Before setting up a development environment for synapse, make sure you have the system dependencies (such as the python header files) installed - see `Installing from source <INSTALL.md#installing-from-source>`_. @@ -260,23 +258,27 @@ directory of your choice:: Synapse has a number of external dependencies, that are easiest to install using pip and a virtualenv:: - virtualenv -p python3 env - source env/bin/activate - python -m pip install --no-use-pep517 -e ".[all]" + python3 -m venv ./env + source ./env/bin/activate + pip install -e ".[all,test]" This will run a process of downloading and installing all the needed -dependencies into a virtual env. +dependencies into a virtual env. If any dependencies fail to install, +try installing the failing modules individually:: + + pip install -e "module-name" -Once this is done, you may wish to run Synapse's unit tests, to -check that everything is installed as it should be:: +Once this is done, you may wish to run Synapse's unit tests to +check that everything is installed correctly:: python -m twisted.trial tests -This should end with a 'PASSED' result:: +This should end with a 'PASSED' result (note that exact numbers will +differ):: - Ran 143 tests in 0.601s + Ran 1337 tests in 716.064s - PASSED (successes=143) + PASSED (skips=15, successes=1322) Running the Integration Tests ============================= @@ -290,19 +292,6 @@ Testing with SyTest is recommended for verifying that changes related to the Client-Server API are functioning correctly. See the `installation instructions <https://github.com/matrix-org/sytest#installing>`_ for details. -Building Internal API Documentation -=================================== - -Before building internal API documentation install sphinx and -sphinxcontrib-napoleon:: - - pip install sphinx - pip install sphinxcontrib-napoleon - -Building internal API documentation:: - - python setup.py build_sphinx - Troubleshooting =============== @@ -387,3 +376,23 @@ something like the following in their logs:: This is normally caused by a misconfiguration in your reverse-proxy. See `<docs/reverse_proxy.md>`_ and double-check that your settings are correct. + +.. |support| image:: https://img.shields.io/matrix/synapse:matrix.org?label=support&logo=matrix + :alt: (get support on #synapse:matrix.org) + :target: https://matrix.to/#/#synapse:matrix.org + +.. |development| image:: https://img.shields.io/matrix/synapse-dev:matrix.org?label=development&logo=matrix + :alt: (discuss development on #synapse-dev:matrix.org) + :target: https://matrix.to/#/#synapse-dev:matrix.org + +.. |license| image:: https://img.shields.io/github/license/matrix-org/synapse + :alt: (check license in LICENSE file) + :target: LICENSE + +.. 
|pypi| image:: https://img.shields.io/pypi/v/matrix-synapse + :alt: (latest version released on PyPi) + :target: https://pypi.org/project/matrix-synapse + +.. |python| image:: https://img.shields.io/pypi/pyversions/matrix-synapse + :alt: (supported python versions) + :target: https://pypi.org/project/matrix-synapse diff --git a/UPGRADE.rst b/UPGRADE.rst
index 49e86e628f..f750d17da2 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst
@@ -5,6 +5,16 @@ Before upgrading check if any special steps are required to upgrade from the version you currently have installed to the current version of Synapse. The extra instructions that may be required are listed later in this document. +* Check that your versions of Python and PostgreSQL are still supported. + + Synapse follows upstream lifecycles for `Python`_ and `PostgreSQL`_, and + removes support for versions which are no longer maintained. + + The website https://endoflife.date also offers convenient summaries. + + .. _Python: https://devguide.python.org/devcycle/#end-of-life-branches + .. _PostgreSQL: https://www.postgresql.org/support/versioning/ + * If Synapse was installed using `prebuilt packages <INSTALL.md#prebuilt-packages>`_, you will need to follow the normal process for upgrading those packages. @@ -75,6 +85,124 @@ for example: wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb +Upgrading to v1.25.0 +==================== + +Last release supporting Python 3.5 +---------------------------------- + +This is the last release of Synapse which guarantees support with Python 3.5, +which passed its upstream End of Life date several months ago. + +We will attempt to maintain support through March 2021, but without guarantees. + +In the future, Synapse will follow upstream schedules for ending support of +older versions of Python and PostgreSQL. Please upgrade to at least Python 3.6 +and PostgreSQL 9.6 as soon as possible. + +Blacklisting IP ranges +---------------------- + +Synapse v1.25.0 includes new settings, ``ip_range_blacklist`` and +``ip_range_whitelist``, for controlling outgoing requests from Synapse for federation, +identity servers, push, and for checking key validity for third-party invite events. +The previous setting, ``federation_ip_range_blacklist``, is deprecated. The new +``ip_range_blacklist`` defaults to private IP ranges if it is not defined. + +If you have never customised ``federation_ip_range_blacklist`` it is recommended +that you remove that setting. + +If you have customised ``federation_ip_range_blacklist`` you should update the +setting name to ``ip_range_blacklist``. + +If you have a custom push server that is reached via private IP space you may +need to customise ``ip_range_blacklist`` or ``ip_range_whitelist``. + +Upgrading to v1.24.0 +==================== + +Custom OpenID Connect mapping provider breaking change +------------------------------------------------------ + +This release allows the OpenID Connect mapping provider to perform normalisation +of the localpart of the Matrix ID. This allows for the mapping provider to +specify different algorithms, instead of the [default way](https://matrix.org/docs/spec/appendices#mapping-from-other-character-sets). + +If your Synapse configuration uses a custom mapping provider +(`oidc_config.user_mapping_provider.module` is specified and not equal to +`synapse.handlers.oidc_handler.JinjaOidcMappingProvider`) then you *must* ensure +that `map_user_attributes` of the mapping provider performs some normalisation +of the `localpart` returned. To match previous behaviour you can use the +`map_username_to_mxid_localpart` function provided by Synapse. An example is +shown below: + +.. code-block:: python + + from synapse.types import map_username_to_mxid_localpart + + class MyMappingProvider: + def map_user_attributes(self, userinfo, token): + # ... your custom logic ... 
+ sso_user_id = ... + localpart = map_username_to_mxid_localpart(sso_user_id) + + return {"localpart": localpart} + +Removal historical Synapse Admin API +------------------------------------ + +Historically, the Synapse Admin API has been accessible under: + +* ``/_matrix/client/api/v1/admin`` +* ``/_matrix/client/unstable/admin`` +* ``/_matrix/client/r0/admin`` +* ``/_synapse/admin/v1`` + +The endpoints with ``/_matrix/client/*`` prefixes have been removed as of v1.24.0. +The Admin API is now only accessible under: + +* ``/_synapse/admin/v1`` + +The only exception is the `/admin/whois` endpoint, which is +`also available via the client-server API <https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-admin-whois-userid>`_. + +The deprecation of the old endpoints was announced with Synapse 1.20.0 (released +on 2020-09-22) and makes it easier for homeserver admins to lock down external +access to the Admin API endpoints. + +Upgrading to v1.23.0 +==================== + +Structured logging configuration breaking changes +------------------------------------------------- + +This release deprecates use of the ``structured: true`` logging configuration for +structured logging. If your logging configuration contains ``structured: true`` +then it should be modified based on the `structured logging documentation +<https://github.com/matrix-org/synapse/blob/master/docs/structured_logging.md>`_. + +The ``structured`` and ``drains`` logging options are now deprecated and should +be replaced by standard logging configuration of ``handlers`` and ``formatters``. + +A future will release of Synapse will make using ``structured: true`` an error. + +Upgrading to v1.22.0 +==================== + +ThirdPartyEventRules breaking changes +------------------------------------- + +This release introduces a backwards-incompatible change to modules making use of +``ThirdPartyEventRules`` in Synapse. If you make use of a module defined under the +``third_party_event_rules`` config option, please make sure it is updated to handle +the below change: + +The ``http_client`` argument is no longer passed to modules as they are initialised. Instead, +modules are expected to make use of the ``http_client`` property on the ``ModuleApi`` class. +Modules are now passed a ``module_api`` argument during initialisation, which is an instance of +``ModuleApi``. ``ModuleApi`` instances have a ``http_client`` property which acts the same as +the ``http_client`` argument previously passed to ``ThirdPartyEventRules`` modules. + Upgrading to v1.21.0 ==================== diff --git a/changelog.d/8530.bugfix b/changelog.d/8530.bugfix deleted file mode 100644
index 443d88424e..0000000000 --- a/changelog.d/8530.bugfix +++ /dev/null
@@ -1 +0,0 @@ -Fix rare bug where sending an event would fail due to a racey assertion. diff --git a/contrib/grafana/README.md b/contrib/grafana/README.md
index ca780d412e..4608793394 100644 --- a/contrib/grafana/README.md +++ b/contrib/grafana/README.md
@@ -3,4 +3,4 @@ 0. Set up Prometheus and Grafana. Out of scope for this readme. Useful documentation about using Grafana with Prometheus: http://docs.grafana.org/features/datasources/prometheus/ 1. Have your Prometheus scrape your Synapse. https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.md 2. Import dashboard into Grafana. Download `synapse.json`. Import it to Grafana and select the correct Prometheus datasource. http://docs.grafana.org/reference/export_import/ -3. Set up additional recording rules +3. Set up required recording rules. https://github.com/matrix-org/synapse/tree/master/contrib/prometheus diff --git a/contrib/prometheus/README.md b/contrib/prometheus/README.md
index e646cb7ea7..b3f23bcc80 100644 --- a/contrib/prometheus/README.md +++ b/contrib/prometheus/README.md
@@ -20,6 +20,7 @@ Add a new job to the main prometheus.conf file: ``` ### for Prometheus v2 + Add a new job to the main prometheus.yml file: ```yaml @@ -29,14 +30,17 @@ Add a new job to the main prometheus.yml file: scheme: "https" static_configs: - - targets: ['SERVER.LOCATION:PORT'] + - targets: ["my.server.here:port"] ``` +An example of a Prometheus configuration with workers can be found in +[metrics-howto.md](https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.md). + To use `synapse.rules` add ```yaml - rule_files: - - "/PATH/TO/synapse-v2.rules" + rule_files: + - "/PATH/TO/synapse-v2.rules" ``` Metrics are disabled by default when running synapse; they must be enabled diff --git a/contrib/prometheus/consoles/synapse.html b/contrib/prometheus/consoles/synapse.html
index 69aa87f85e..cd9ad15231 100644 --- a/contrib/prometheus/consoles/synapse.html +++ b/contrib/prometheus/consoles/synapse.html
@@ -9,7 +9,7 @@ new PromConsole.Graph({ node: document.querySelector("#process_resource_utime"), expr: "rate(process_cpu_seconds_total[2m]) * 100", - name: "[[job]]", + name: "[[job]]-[[index]]", min: 0, max: 100, renderer: "line", @@ -22,12 +22,12 @@ new PromConsole.Graph({ </script> <h3>Memory</h3> -<div id="process_resource_maxrss"></div> +<div id="process_resident_memory_bytes"></div> <script> new PromConsole.Graph({ - node: document.querySelector("#process_resource_maxrss"), - expr: "process_psutil_rss:max", - name: "Maxrss", + node: document.querySelector("#process_resident_memory_bytes"), + expr: "process_resident_memory_bytes", + name: "[[job]]-[[index]]", min: 0, renderer: "line", height: 150, @@ -43,8 +43,8 @@ new PromConsole.Graph({ <script> new PromConsole.Graph({ node: document.querySelector("#process_fds"), - expr: "process_open_fds{job='synapse'}", - name: "FDs", + expr: "process_open_fds", + name: "[[job]]-[[index]]", min: 0, renderer: "line", height: 150, @@ -62,8 +62,8 @@ new PromConsole.Graph({ <script> new PromConsole.Graph({ node: document.querySelector("#reactor_total_time"), - expr: "rate(python_twisted_reactor_tick_time:total[2m]) / 1000", - name: "time", + expr: "rate(python_twisted_reactor_tick_time_sum[2m])", + name: "[[job]]-[[index]]", max: 1, min: 0, renderer: "area", @@ -80,8 +80,8 @@ new PromConsole.Graph({ <script> new PromConsole.Graph({ node: document.querySelector("#reactor_average_time"), - expr: "rate(python_twisted_reactor_tick_time:total[2m]) / rate(python_twisted_reactor_tick_time:count[2m]) / 1000", - name: "time", + expr: "rate(python_twisted_reactor_tick_time_sum[2m]) / rate(python_twisted_reactor_tick_time_count[2m])", + name: "[[job]]-[[index]]", min: 0, renderer: "line", height: 150, @@ -97,14 +97,14 @@ new PromConsole.Graph({ <script> new PromConsole.Graph({ node: document.querySelector("#reactor_pending_calls"), - expr: "rate(python_twisted_reactor_pending_calls:total[30s])/rate(python_twisted_reactor_pending_calls:count[30s])", - name: "calls", + expr: "rate(python_twisted_reactor_pending_calls_sum[30s]) / rate(python_twisted_reactor_pending_calls_count[30s])", + name: "[[job]]-[[index]]", min: 0, renderer: "line", height: 150, yAxisFormatter: PromConsole.NumberFormatter.humanize, yHoverFormatter: PromConsole.NumberFormatter.humanize, - yTitle: "Pending Cals" + yTitle: "Pending Calls" }) </script> @@ -115,7 +115,7 @@ new PromConsole.Graph({ <script> new PromConsole.Graph({ node: document.querySelector("#synapse_storage_query_time"), - expr: "rate(synapse_storage_query_time:count[2m])", + expr: "sum(rate(synapse_storage_query_time_count[2m])) by (verb)", name: "[[verb]]", yAxisFormatter: PromConsole.NumberFormatter.humanizeNoSmallPrefix, yHoverFormatter: PromConsole.NumberFormatter.humanizeNoSmallPrefix, @@ -129,8 +129,8 @@ new PromConsole.Graph({ <script> new PromConsole.Graph({ node: document.querySelector("#synapse_storage_transaction_time"), - expr: "rate(synapse_storage_transaction_time:count[2m])", - name: "[[desc]]", + expr: "topk(10, rate(synapse_storage_transaction_time_count[2m]))", + name: "[[job]]-[[index]] [[desc]]", min: 0, yAxisFormatter: PromConsole.NumberFormatter.humanizeNoSmallPrefix, yHoverFormatter: PromConsole.NumberFormatter.humanizeNoSmallPrefix, @@ -140,12 +140,12 @@ new PromConsole.Graph({ </script> <h3>Transaction execution time</h3> -<div id="synapse_storage_transactions_time_msec"></div> +<div id="synapse_storage_transactions_time_sec"></div> <script> new PromConsole.Graph({ - node: 
document.querySelector("#synapse_storage_transactions_time_msec"), - expr: "rate(synapse_storage_transaction_time:total[2m]) / 1000", - name: "[[desc]]", + node: document.querySelector("#synapse_storage_transactions_time_sec"), + expr: "rate(synapse_storage_transaction_time_sum[2m])", + name: "[[job]]-[[index]] [[desc]]", min: 0, yAxisFormatter: PromConsole.NumberFormatter.humanize, yHoverFormatter: PromConsole.NumberFormatter.humanize, @@ -154,34 +154,33 @@ new PromConsole.Graph({ }) </script> -<h3>Database scheduling latency</h3> -<div id="synapse_storage_schedule_time"></div> +<h3>Average time waiting for database connection</h3> +<div id="synapse_storage_avg_waiting_time"></div> <script> new PromConsole.Graph({ - node: document.querySelector("#synapse_storage_schedule_time"), - expr: "rate(synapse_storage_schedule_time:total[2m]) / 1000", - name: "Total latency", + node: document.querySelector("#synapse_storage_avg_waiting_time"), + expr: "rate(synapse_storage_schedule_time_sum[2m]) / rate(synapse_storage_schedule_time_count[2m])", + name: "[[job]]-[[index]]", min: 0, yAxisFormatter: PromConsole.NumberFormatter.humanize, yHoverFormatter: PromConsole.NumberFormatter.humanize, - yUnits: "s/s", - yTitle: "Usage" + yUnits: "s", + yTitle: "Time" }) </script> -<h3>Cache hit ratio</h3> -<div id="synapse_cache_ratio"></div> +<h3>Cache request rate</h3> +<div id="synapse_cache_request_rate"></div> <script> new PromConsole.Graph({ - node: document.querySelector("#synapse_cache_ratio"), - expr: "rate(synapse_util_caches_cache:total[2m]) * 100", - name: "[[name]]", + node: document.querySelector("#synapse_cache_request_rate"), + expr: "rate(synapse_util_caches_cache:total[2m])", + name: "[[job]]-[[index]] [[name]]", min: 0, - max: 100, yAxisFormatter: PromConsole.NumberFormatter.humanizeNoSmallPrefix, yHoverFormatter: PromConsole.NumberFormatter.humanizeNoSmallPrefix, - yUnits: "%", - yTitle: "Percentage" + yUnits: "rps", + yTitle: "Cache request rate" }) </script> @@ -191,7 +190,7 @@ new PromConsole.Graph({ new PromConsole.Graph({ node: document.querySelector("#synapse_cache_size"), expr: "synapse_util_caches_cache:size", - name: "[[name]]", + name: "[[job]]-[[index]] [[name]]", yAxisFormatter: PromConsole.NumberFormatter.humanizeNoSmallPrefix, yHoverFormatter: PromConsole.NumberFormatter.humanizeNoSmallPrefix, yUnits: "", @@ -206,8 +205,8 @@ new PromConsole.Graph({ <script> new PromConsole.Graph({ node: document.querySelector("#synapse_http_server_request_count_servlet"), - expr: "rate(synapse_http_server_request_count:servlet[2m])", - name: "[[servlet]]", + expr: "rate(synapse_http_server_in_flight_requests_count[2m])", + name: "[[job]]-[[index]] [[method]] [[servlet]]", yAxisFormatter: PromConsole.NumberFormatter.humanize, yHoverFormatter: PromConsole.NumberFormatter.humanize, yUnits: "req/s", @@ -219,8 +218,8 @@ new PromConsole.Graph({ <script> new PromConsole.Graph({ node: document.querySelector("#synapse_http_server_request_count_servlet_minus_events"), - expr: "rate(synapse_http_server_request_count:servlet{servlet!=\"EventStreamRestServlet\", servlet!=\"SyncRestServlet\"}[2m])", - name: "[[servlet]]", + expr: "rate(synapse_http_server_in_flight_requests_count{servlet!=\"EventStreamRestServlet\", servlet!=\"SyncRestServlet\"}[2m])", + name: "[[job]]-[[index]] [[method]] [[servlet]]", yAxisFormatter: PromConsole.NumberFormatter.humanize, yHoverFormatter: PromConsole.NumberFormatter.humanize, yUnits: "req/s", @@ -233,8 +232,8 @@ new PromConsole.Graph({ <script> new PromConsole.Graph({ node: 
document.querySelector("#synapse_http_server_response_time_avg"), - expr: "rate(synapse_http_server_response_time_seconds[2m]) / rate(synapse_http_server_response_count[2m]) / 1000", - name: "[[servlet]]", + expr: "rate(synapse_http_server_response_time_seconds_sum[2m]) / rate(synapse_http_server_response_count[2m])", + name: "[[job]]-[[index]] [[servlet]]", yAxisFormatter: PromConsole.NumberFormatter.humanize, yHoverFormatter: PromConsole.NumberFormatter.humanize, yUnits: "s/req", @@ -277,7 +276,7 @@ new PromConsole.Graph({ new PromConsole.Graph({ node: document.querySelector("#synapse_http_server_response_ru_utime"), expr: "rate(synapse_http_server_response_ru_utime_seconds[2m])", - name: "[[servlet]]", + name: "[[job]]-[[index]] [[servlet]]", yAxisFormatter: PromConsole.NumberFormatter.humanize, yHoverFormatter: PromConsole.NumberFormatter.humanize, yUnits: "s/s", @@ -292,7 +291,7 @@ new PromConsole.Graph({ new PromConsole.Graph({ node: document.querySelector("#synapse_http_server_response_db_txn_duration"), expr: "rate(synapse_http_server_response_db_txn_duration_seconds[2m])", - name: "[[servlet]]", + name: "[[job]]-[[index]] [[servlet]]", yAxisFormatter: PromConsole.NumberFormatter.humanize, yHoverFormatter: PromConsole.NumberFormatter.humanize, yUnits: "s/s", @@ -306,8 +305,8 @@ new PromConsole.Graph({ <script> new PromConsole.Graph({ node: document.querySelector("#synapse_http_server_send_time_avg"), - expr: "rate(synapse_http_server_response_time_second{servlet='RoomSendEventRestServlet'}[2m]) / rate(synapse_http_server_response_count{servlet='RoomSendEventRestServlet'}[2m]) / 1000", - name: "[[servlet]]", + expr: "rate(synapse_http_server_response_time_seconds_sum{servlet='RoomSendEventRestServlet'}[2m]) / rate(synapse_http_server_response_count{servlet='RoomSendEventRestServlet'}[2m])", + name: "[[job]]-[[index]] [[servlet]]", yAxisFormatter: PromConsole.NumberFormatter.humanize, yHoverFormatter: PromConsole.NumberFormatter.humanize, yUnits: "s/req", @@ -323,7 +322,7 @@ new PromConsole.Graph({ new PromConsole.Graph({ node: document.querySelector("#synapse_federation_client_sent"), expr: "rate(synapse_federation_client_sent[2m])", - name: "[[type]]", + name: "[[job]]-[[index]] [[type]]", yAxisFormatter: PromConsole.NumberFormatter.humanize, yHoverFormatter: PromConsole.NumberFormatter.humanize, yUnits: "req/s", @@ -337,7 +336,7 @@ new PromConsole.Graph({ new PromConsole.Graph({ node: document.querySelector("#synapse_federation_server_received"), expr: "rate(synapse_federation_server_received[2m])", - name: "[[type]]", + name: "[[job]]-[[index]] [[type]]", yAxisFormatter: PromConsole.NumberFormatter.humanize, yHoverFormatter: PromConsole.NumberFormatter.humanize, yUnits: "req/s", @@ -367,7 +366,7 @@ new PromConsole.Graph({ new PromConsole.Graph({ node: document.querySelector("#synapse_notifier_listeners"), expr: "synapse_notifier_listeners", - name: "listeners", + name: "[[job]]-[[index]]", min: 0, yAxisFormatter: PromConsole.NumberFormatter.humanizeNoSmallPrefix, yHoverFormatter: PromConsole.NumberFormatter.humanizeNoSmallPrefix, @@ -382,7 +381,7 @@ new PromConsole.Graph({ new PromConsole.Graph({ node: document.querySelector("#synapse_notifier_notified_events"), expr: "rate(synapse_notifier_notified_events[2m])", - name: "events", + name: "[[job]]-[[index]]", yAxisFormatter: PromConsole.NumberFormatter.humanize, yHoverFormatter: PromConsole.NumberFormatter.humanize, yUnits: "events/s", diff --git a/contrib/prometheus/synapse-v2.rules b/contrib/prometheus/synapse-v2.rules
index 6ccca2daaf..7e405bf7f0 100644 --- a/contrib/prometheus/synapse-v2.rules +++ b/contrib/prometheus/synapse-v2.rules
@@ -58,3 +58,21 @@ groups: labels: type: "PDU" expr: 'synapse_federation_transaction_queue_pending_pdus + 0' + + - record: synapse_storage_events_persisted_by_source_type + expr: sum without(type, origin_type, origin_entity) (synapse_storage_events_persisted_events_sep{origin_type="remote"}) + labels: + type: remote + - record: synapse_storage_events_persisted_by_source_type + expr: sum without(type, origin_type, origin_entity) (synapse_storage_events_persisted_events_sep{origin_entity="*client*",origin_type="local"}) + labels: + type: local + - record: synapse_storage_events_persisted_by_source_type + expr: sum without(type, origin_type, origin_entity) (synapse_storage_events_persisted_events_sep{origin_entity!="*client*",origin_type="local"}) + labels: + type: bridges + - record: synapse_storage_events_persisted_by_event_type + expr: sum without(origin_entity, origin_type) (synapse_storage_events_persisted_events_sep) + - record: synapse_storage_events_persisted_by_origin + expr: sum without(type) (synapse_storage_events_persisted_events_sep) + diff --git a/debian/changelog b/debian/changelog
index eeafd4f50a..609436bf75 100644 --- a/debian/changelog +++ b/debian/changelog
@@ -1,3 +1,51 @@ +matrix-synapse-py3 (1.25.0) stable; urgency=medium + + [ Dan Callahan ] + * Update dependencies to account for the removal of the transitional + dh-systemd package from Debian Bullseye. + + [ Synapse Packaging team ] + * New synapse release 1.25.0. + + -- Synapse Packaging team <packages@matrix.org> Wed, 13 Jan 2021 10:14:55 +0000 + +matrix-synapse-py3 (1.24.0) stable; urgency=medium + + * New synapse release 1.24.0. + + -- Synapse Packaging team <packages@matrix.org> Wed, 09 Dec 2020 10:14:30 +0000 + +matrix-synapse-py3 (1.23.1) stable; urgency=medium + + * New synapse release 1.23.1. + + -- Synapse Packaging team <packages@matrix.org> Wed, 09 Dec 2020 10:40:39 +0000 + +matrix-synapse-py3 (1.23.0) stable; urgency=medium + + * New synapse release 1.23.0. + + -- Synapse Packaging team <packages@matrix.org> Wed, 18 Nov 2020 11:41:28 +0000 + +matrix-synapse-py3 (1.22.1) stable; urgency=medium + + * New synapse release 1.22.1. + + -- Synapse Packaging team <packages@matrix.org> Fri, 30 Oct 2020 15:25:37 +0000 + +matrix-synapse-py3 (1.22.0) stable; urgency=medium + + * New synapse release 1.22.0. + + -- Synapse Packaging team <packages@matrix.org> Tue, 27 Oct 2020 12:07:12 +0000 + +matrix-synapse-py3 (1.21.2) stable; urgency=medium + + [ Synapse Packaging team ] + * New synapse release 1.21.2. + + -- Synapse Packaging team <packages@matrix.org> Thu, 15 Oct 2020 09:23:27 -0400 + matrix-synapse-py3 (1.21.1) stable; urgency=medium [ Synapse Packaging team ] diff --git a/debian/control b/debian/control
index bae14b41e4..b10401be43 100644 --- a/debian/control +++ b/debian/control
@@ -3,9 +3,11 @@ Section: contrib/python Priority: extra Maintainer: Synapse Packaging team <packages@matrix.org> # keep this list in sync with the build dependencies in docker/Dockerfile-dhvirtualenv. +# TODO: Remove the dependency on dh-systemd after dropping support for Ubuntu xenial +# On all other supported releases, it's merely a transitional package which +# does nothing but depends on debhelper (> 9.20160709) Build-Depends: - debhelper (>= 9), - dh-systemd, + debhelper (>= 9.20160709) | dh-systemd, dh-virtualenv (>= 1.1), libsystemd-dev, libpq-dev, diff --git a/demo/start.sh b/demo/start.sh
index 83396e5c33..f6b5ea137f 100755 --- a/demo/start.sh +++ b/demo/start.sh
@@ -30,6 +30,8 @@ for port in 8080 8081 8082; do if ! grep -F "Customisation made by demo/start.sh" -q $DIR/etc/$port.config; then printf '\n\n# Customisation made by demo/start.sh\n' >> $DIR/etc/$port.config + echo "public_baseurl: http://localhost:$port/" >> $DIR/etc/$port.config + echo 'enable_registration: true' >> $DIR/etc/$port.config # Warning, this heredoc depends on the interaction of tabs and spaces. Please don't diff --git a/docker/Dockerfile b/docker/Dockerfile
index 27512f8600..afd896ffc1 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile
@@ -11,7 +11,7 @@ # docker build -f docker/Dockerfile --build-arg PYTHON_VERSION=3.6 . # -ARG PYTHON_VERSION=3.7 +ARG PYTHON_VERSION=3.8 ### ### Stage 0: builder @@ -36,7 +36,8 @@ RUN pip install --prefix="/install" --no-warn-script-location \ frozendict \ jaeger-client \ opentracing \ - prometheus-client \ + # Match the version constraints of Synapse + "prometheus_client>=0.4.0" \ psycopg2 \ pycparser \ pyrsistent \ diff --git a/docker/Dockerfile-dhvirtualenv b/docker/Dockerfile-dhvirtualenv
index 619585d5fa..e529293803 100644 --- a/docker/Dockerfile-dhvirtualenv +++ b/docker/Dockerfile-dhvirtualenv
@@ -50,17 +50,22 @@ FROM ${distro} ARG distro="" ENV distro ${distro} +# Python < 3.7 assumes LANG="C" means ASCII-only and throws on printing unicode +# http://bugs.python.org/issue19846 +ENV LANG C.UTF-8 + # Install the build dependencies # # NB: keep this list in sync with the list of build-deps in debian/control # TODO: it would be nice to do that automatically. +# TODO: Remove the dh-systemd stanza after dropping support for Ubuntu xenial +# it's a transitional package on all other, more recent releases RUN apt-get update -qq -o Acquire::Languages=none \ && env DEBIAN_FRONTEND=noninteractive apt-get install \ -yqq --no-install-recommends -o Dpkg::Options::=--force-unsafe-io \ build-essential \ debhelper \ devscripts \ - dh-systemd \ libsystemd-dev \ lsb-release \ pkg-config \ @@ -69,7 +74,11 @@ RUN apt-get update -qq -o Acquire::Languages=none \ python3-setuptools \ python3-venv \ sqlite3 \ - libpq-dev + libpq-dev \ + xmlsec1 \ + && ( env DEBIAN_FRONTEND=noninteractive apt-get install \ + -yqq --no-install-recommends -o Dpkg::Options::=--force-unsafe-io \ + dh-systemd || true ) COPY --from=builder /dh-virtualenv_1.2~dev-1_all.deb / diff --git a/docker/README.md b/docker/README.md
index d0da34778e..c8f27b8566 100644 --- a/docker/README.md +++ b/docker/README.md
@@ -83,7 +83,7 @@ docker logs synapse If all is well, you should now be able to connect to http://localhost:8008 and see a confirmation message. -The following environment variables are supported in run mode: +The following environment variables are supported in `run` mode: * `SYNAPSE_CONFIG_DIR`: where additional config files are stored. Defaults to `/data`. @@ -94,6 +94,20 @@ The following environment variables are supported in run mode: * `UID`, `GID`: the user and group id to run Synapse as. Defaults to `991`, `991`. * `TZ`: the [timezone](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones) the container will run with. Defaults to `UTC`. +For more complex setups (e.g. for workers) you can also pass your args directly to synapse using `run` mode. For example like this: + +``` +docker run -d --name synapse \ + --mount type=volume,src=synapse-data,dst=/data \ + -p 8008:8008 \ + matrixdotorg/synapse:latest run \ + -m synapse.app.generic_worker \ + --config-path=/data/homeserver.yaml \ + --config-path=/data/generic_worker.yaml +``` + +If you do not provide `-m`, the value of the `SYNAPSE_WORKER` environment variable is used. If you do not provide at least one `--config-path` or `-c`, the value of the `SYNAPSE_CONFIG_PATH` environment variable is used instead. + ## Generating an (admin) user After synapse is running, you may wish to create a user via `register_new_matrix_user`. diff --git a/docker/conf/homeserver.yaml b/docker/conf/homeserver.yaml
index c1110f0f53..a808485c12 100644 --- a/docker/conf/homeserver.yaml +++ b/docker/conf/homeserver.yaml
@@ -90,7 +90,7 @@ federation_rc_concurrent: 3 media_store_path: "/data/media" uploads_path: "/data/uploads" -max_upload_size: "{{ SYNAPSE_MAX_UPLOAD_SIZE or "10M" }}" +max_upload_size: "{{ SYNAPSE_MAX_UPLOAD_SIZE or "50M" }}" max_image_pixels: "32M" dynamic_thumbnails: false diff --git a/docker/start.py b/docker/start.py
index 9f08134158..0d2c590b88 100755 --- a/docker/start.py +++ b/docker/start.py
@@ -179,7 +179,7 @@ def run_generate_config(environ, ownership): def main(args, environ): - mode = args[1] if len(args) > 1 else None + mode = args[1] if len(args) > 1 else "run" desired_uid = int(environ.get("UID", "991")) desired_gid = int(environ.get("GID", "991")) synapse_worker = environ.get("SYNAPSE_WORKER", "synapse.app.homeserver") @@ -205,36 +205,47 @@ def main(args, environ): config_dir, config_path, environ, ownership ) - if mode is not None: + if mode != "run": error("Unknown execution mode '%s'" % (mode,)) - config_dir = environ.get("SYNAPSE_CONFIG_DIR", "/data") - config_path = environ.get("SYNAPSE_CONFIG_PATH", config_dir + "/homeserver.yaml") + args = args[2:] - if not os.path.exists(config_path): - if "SYNAPSE_SERVER_NAME" in environ: - error( - """\ + if "-m" not in args: + args = ["-m", synapse_worker] + args + + # if there are no config files passed to synapse, try adding the default file + if not any(p.startswith("--config-path") or p.startswith("-c") for p in args): + config_dir = environ.get("SYNAPSE_CONFIG_DIR", "/data") + config_path = environ.get( + "SYNAPSE_CONFIG_PATH", config_dir + "/homeserver.yaml" + ) + + if not os.path.exists(config_path): + if "SYNAPSE_SERVER_NAME" in environ: + error( + """\ Config file '%s' does not exist. The synapse docker image no longer supports generating a config file on-the-fly based on environment variables. You can migrate to a static config file by running with 'migrate_config'. See the README for more details. """ + % (config_path,) + ) + + error( + "Config file '%s' does not exist. You should either create a new " + "config file by running with the `generate` argument (and then edit " + "the resulting file before restarting) or specify the path to an " + "existing config file with the SYNAPSE_CONFIG_PATH variable." % (config_path,) ) - error( - "Config file '%s' does not exist. You should either create a new " - "config file by running with the `generate` argument (and then edit " - "the resulting file before restarting) or specify the path to an " - "existing config file with the SYNAPSE_CONFIG_PATH variable." - % (config_path,) - ) + args += ["--config-path", config_path] - log("Starting synapse with config file " + config_path) + log("Starting synapse with args " + " ".join(args)) - args = ["python", "-m", synapse_worker, "--config-path", config_path] + args = ["python"] + args if ownership is not None: args = ["gosu", ownership] + args os.execv("/usr/sbin/gosu", args) diff --git a/docs/admin_api/event_reports.md b/docs/admin_api/event_reports.md new file mode 100644
index 0000000000..0159098138 --- /dev/null +++ b/docs/admin_api/event_reports.md
@@ -0,0 +1,172 @@ +# Show reported events + +This API returns information about reported events. + +The api is: +``` +GET /_synapse/admin/v1/event_reports?from=0&limit=10 +``` +To use it, you will need to authenticate by providing an `access_token` for a +server admin: see [README.rst](README.rst). + +It returns a JSON body like the following: + +```json +{ + "event_reports": [ + { + "event_id": "$bNUFCwGzWca1meCGkjp-zwslF-GfVcXukvRLI1_FaVY", + "id": 2, + "reason": "foo", + "score": -100, + "received_ts": 1570897107409, + "canonical_alias": "#alias1:matrix.org", + "room_id": "!ERAgBpSOcCCuTJqQPk:matrix.org", + "name": "Matrix HQ", + "sender": "@foobar:matrix.org", + "user_id": "@foo:matrix.org" + }, + { + "event_id": "$3IcdZsDaN_En-S1DF4EMCy3v4gNRKeOJs8W5qTOKj4I", + "id": 3, + "reason": "bar", + "score": -100, + "received_ts": 1598889612059, + "canonical_alias": "#alias2:matrix.org", + "room_id": "!eGvUQuTCkHGVwNMOjv:matrix.org", + "name": "Your room name here", + "sender": "@foobar:matrix.org", + "user_id": "@bar:matrix.org" + } + ], + "next_token": 2, + "total": 4 +} +``` + +To paginate, check for `next_token` and if present, call the endpoint again with `from` +set to the value of `next_token`. This will return a new page. + +If the endpoint does not return a `next_token` then there are no more reports to +paginate through. + +**URL parameters:** + +* `limit`: integer - Is optional but is used for pagination, denoting the maximum number + of items to return in this call. Defaults to `100`. +* `from`: integer - Is optional but used for pagination, denoting the offset in the + returned results. This should be treated as an opaque value and not explicitly set to + anything other than the return value of `next_token` from a previous call. Defaults to `0`. +* `dir`: string - Direction of event report order. Whether to fetch the most recent + first (`b`) or the oldest first (`f`). Defaults to `b`. +* `user_id`: string - Is optional and filters to only return users with user IDs that + contain this value. This is the user who reported the event and wrote the reason. +* `room_id`: string - Is optional and filters to only return rooms with room IDs that + contain this value. + +**Response** + +The following fields are returned in the JSON response body: + +* `id`: integer - ID of event report. +* `received_ts`: integer - The timestamp (in milliseconds since the unix epoch) when this + report was sent. +* `room_id`: string - The ID of the room in which the event being reported is located. +* `name`: string - The name of the room. +* `event_id`: string - The ID of the reported event. +* `user_id`: string - This is the user who reported the event and wrote the reason. +* `reason`: string - Comment made by the `user_id` in this report. May be blank. +* `score`: integer - Content is reported based upon a negative score, where -100 is + "most offensive" and 0 is "inoffensive". +* `sender`: string - This is the ID of the user who sent the original message/event that + was reported. +* `canonical_alias`: string - The canonical alias of the room. `null` if the room does not + have a canonical alias set. +* `next_token`: integer - Indication for pagination. See above. +* `total`: integer - Total number of event reports related to the query + (`user_id` and `room_id`). + +# Show details of a specific event report + +This API returns information about a specific event report. 
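Tying the two endpoints together, a typical admin workflow is to page through the report list shown above and then fetch the full details of a single report via the endpoint described below. A minimal sketch using `curl` follows; the host, access token, and report ID are placeholder values, not part of the API definition:

```sh
# Placeholder values - substitute your own homeserver and admin access token.
HOST="https://matrix.example.com"
TOKEN="<admin access token>"

# Page through the report list; repeat with from=<next_token> while one is returned.
curl --header "Authorization: Bearer $TOKEN" \
    "$HOST/_synapse/admin/v1/event_reports?from=0&limit=10&dir=b"

# Fetch the full details (including event_json) of report ID 2.
curl --header "Authorization: Bearer $TOKEN" \
    "$HOST/_synapse/admin/v1/event_reports/2"
```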
+ +The api is: +``` +GET /_synapse/admin/v1/event_reports/<report_id> +``` +To use it, you will need to authenticate by providing an `access_token` for a +server admin: see [README.rst](README.rst). + +It returns a JSON body like the following: + +```jsonc +{ + "event_id": "$bNUFCwGzWca1meCGkjp-zwslF-GfVcXukvRLI1_FaVY", + "event_json": { + "auth_events": [ + "$YK4arsKKcc0LRoe700pS8DSjOvUT4NDv0HfInlMFw2M", + "$oggsNXxzPFRE3y53SUNd7nsj69-QzKv03a1RucHu-ws" + ], + "content": { + "body": "matrix.org: This Week in Matrix", + "format": "org.matrix.custom.html", + "formatted_body": "<strong>matrix.org</strong>:<br><a href=\"https://matrix.org/blog/\"><strong>This Week in Matrix</strong></a>", + "msgtype": "m.notice" + }, + "depth": 546, + "hashes": { + "sha256": "xK1//xnmvHJIOvbgXlkI8eEqdvoMmihVDJ9J4SNlsAw" + }, + "origin": "matrix.org", + "origin_server_ts": 1592291711430, + "prev_events": [ + "$YK4arsKKcc0LRoe700pS8DSjOvUT4NDv0HfInlMFw2M" + ], + "prev_state": [], + "room_id": "!ERAgBpSOcCCuTJqQPk:matrix.org", + "sender": "@foobar:matrix.org", + "signatures": { + "matrix.org": { + "ed25519:a_JaEG": "cs+OUKW/iHx5pEidbWxh0UiNNHwe46Ai9LwNz+Ah16aWDNszVIe2gaAcVZfvNsBhakQTew51tlKmL2kspXk/Dg" + } + }, + "type": "m.room.message", + "unsigned": { + "age_ts": 1592291711430, + } + }, + "id": <report_id>, + "reason": "foo", + "score": -100, + "received_ts": 1570897107409, + "canonical_alias": "#alias1:matrix.org", + "room_id": "!ERAgBpSOcCCuTJqQPk:matrix.org", + "name": "Matrix HQ", + "sender": "@foobar:matrix.org", + "user_id": "@foo:matrix.org" +} +``` + +**URL parameters:** + +* `report_id`: string - The ID of the event report. + +**Response** + +The following fields are returned in the JSON response body: + +* `id`: integer - ID of event report. +* `received_ts`: integer - The timestamp (in milliseconds since the unix epoch) when this + report was sent. +* `room_id`: string - The ID of the room in which the event being reported is located. +* `name`: string - The name of the room. +* `event_id`: string - The ID of the reported event. +* `user_id`: string - This is the user who reported the event and wrote the reason. +* `reason`: string - Comment made by the `user_id` in this report. May be blank. +* `score`: integer - Content is reported based upon a negative score, where -100 is + "most offensive" and 0 is "inoffensive". +* `sender`: string - This is the ID of the user who sent the original message/event that + was reported. +* `canonical_alias`: string - The canonical alias of the room. `null` if the room does not + have a canonical alias set. +* `event_json`: object - Details of the original event that was reported. diff --git a/docs/admin_api/event_reports.rst b/docs/admin_api/event_reports.rst deleted file mode 100644
index 461be01230..0000000000 --- a/docs/admin_api/event_reports.rst +++ /dev/null
@@ -1,129 +0,0 @@ -Show reported events -==================== - -This API returns information about reported events. - -The api is:: - - GET /_synapse/admin/v1/event_reports?from=0&limit=10 - -To use it, you will need to authenticate by providing an ``access_token`` for a -server admin: see `README.rst <README.rst>`_. - -It returns a JSON body like the following: - -.. code:: jsonc - - { - "event_reports": [ - { - "content": { - "reason": "foo", - "score": -100 - }, - "event_id": "$bNUFCwGzWca1meCGkjp-zwslF-GfVcXukvRLI1_FaVY", - "event_json": { - "auth_events": [ - "$YK4arsKKcc0LRoe700pS8DSjOvUT4NDv0HfInlMFw2M", - "$oggsNXxzPFRE3y53SUNd7nsj69-QzKv03a1RucHu-ws" - ], - "content": { - "body": "matrix.org: This Week in Matrix", - "format": "org.matrix.custom.html", - "formatted_body": "<strong>matrix.org</strong>:<br><a href=\"https://matrix.org/blog/\"><strong>This Week in Matrix</strong></a>", - "msgtype": "m.notice" - }, - "depth": 546, - "hashes": { - "sha256": "xK1//xnmvHJIOvbgXlkI8eEqdvoMmihVDJ9J4SNlsAw" - }, - "origin": "matrix.org", - "origin_server_ts": 1592291711430, - "prev_events": [ - "$YK4arsKKcc0LRoe700pS8DSjOvUT4NDv0HfInlMFw2M" - ], - "prev_state": [], - "room_id": "!ERAgBpSOcCCuTJqQPk:matrix.org", - "sender": "@foobar:matrix.org", - "signatures": { - "matrix.org": { - "ed25519:a_JaEG": "cs+OUKW/iHx5pEidbWxh0UiNNHwe46Ai9LwNz+Ah16aWDNszVIe2gaAcVZfvNsBhakQTew51tlKmL2kspXk/Dg" - } - }, - "type": "m.room.message", - "unsigned": { - "age_ts": 1592291711430, - } - }, - "id": 2, - "reason": "foo", - "received_ts": 1570897107409, - "room_alias": "#alias1:matrix.org", - "room_id": "!ERAgBpSOcCCuTJqQPk:matrix.org", - "sender": "@foobar:matrix.org", - "user_id": "@foo:matrix.org" - }, - { - "content": { - "reason": "bar", - "score": -100 - }, - "event_id": "$3IcdZsDaN_En-S1DF4EMCy3v4gNRKeOJs8W5qTOKj4I", - "event_json": { - // hidden items - // see above - }, - "id": 3, - "reason": "bar", - "received_ts": 1598889612059, - "room_alias": "#alias2:matrix.org", - "room_id": "!eGvUQuTCkHGVwNMOjv:matrix.org", - "sender": "@foobar:matrix.org", - "user_id": "@bar:matrix.org" - } - ], - "next_token": 2, - "total": 4 - } - -To paginate, check for ``next_token`` and if present, call the endpoint again -with ``from`` set to the value of ``next_token``. This will return a new page. - -If the endpoint does not return a ``next_token`` then there are no more -reports to paginate through. - -**URL parameters:** - -- ``limit``: integer - Is optional but is used for pagination, - denoting the maximum number of items to return in this call. Defaults to ``100``. -- ``from``: integer - Is optional but used for pagination, - denoting the offset in the returned results. This should be treated as an opaque value and - not explicitly set to anything other than the return value of ``next_token`` from a previous call. - Defaults to ``0``. -- ``dir``: string - Direction of event report order. Whether to fetch the most recent first (``b``) or the - oldest first (``f``). Defaults to ``b``. -- ``user_id``: string - Is optional and filters to only return users with user IDs that contain this value. - This is the user who reported the event and wrote the reason. -- ``room_id``: string - Is optional and filters to only return rooms with room IDs that contain this value. - -**Response** - -The following fields are returned in the JSON response body: - -- ``id``: integer - ID of event report. -- ``received_ts``: integer - The timestamp (in milliseconds since the unix epoch) when this report was sent. 
-- ``room_id``: string - The ID of the room in which the event being reported is located. -- ``event_id``: string - The ID of the reported event. -- ``user_id``: string - This is the user who reported the event and wrote the reason. -- ``reason``: string - Comment made by the ``user_id`` in this report. May be blank. -- ``content``: object - Content of reported event. - - - ``reason``: string - Comment made by the ``user_id`` in this report. May be blank. - - ``score``: integer - Content is reported based upon a negative score, where -100 is "most offensive" and 0 is "inoffensive". - -- ``sender``: string - This is the ID of the user who sent the original message/event that was reported. -- ``room_alias``: string - The alias of the room. ``null`` if the room does not have a canonical alias set. -- ``event_json``: object - Details of the original event that was reported. -- ``next_token``: integer - Indication for pagination. See above. -- ``total``: integer - Total number of event reports related to the query (``user_id`` and ``room_id``). - diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md
index 26948770d8..dfb8c5d751 100644 --- a/docs/admin_api/media_admin_api.md +++ b/docs/admin_api/media_admin_api.md
@@ -1,6 +1,18 @@ +# Contents +- [List all media in a room](#list-all-media-in-a-room) +- [Quarantine media](#quarantine-media) + * [Quarantining media by ID](#quarantining-media-by-id) + * [Quarantining media in a room](#quarantining-media-in-a-room) + * [Quarantining all media of a user](#quarantining-all-media-of-a-user) +- [Delete local media](#delete-local-media) + * [Delete a specific local media](#delete-a-specific-local-media) + * [Delete local media by date or size](#delete-local-media-by-date-or-size) +- [Purge Remote Media API](#purge-remote-media-api) + # List all media in a room This API gets a list of known media in a room. +However, it only shows media from unencrypted events or rooms. The API is: ``` @@ -10,16 +22,16 @@ To use it, you will need to authenticate by providing an `access_token` for a server admin: see [README.rst](README.rst). The API returns a JSON body like the following: -``` +```json { - "local": [ - "mxc://localhost/xwvutsrqponmlkjihgfedcba", - "mxc://localhost/abcdefghijklmnopqrstuvwx" - ], - "remote": [ - "mxc://matrix.org/xwvutsrqponmlkjihgfedcba", - "mxc://matrix.org/abcdefghijklmnopqrstuvwx" - ] + "local": [ + "mxc://localhost/xwvutsrqponmlkjihgfedcba", + "mxc://localhost/abcdefghijklmnopqrstuvwx" + ], + "remote": [ + "mxc://matrix.org/xwvutsrqponmlkjihgfedcba", + "mxc://matrix.org/abcdefghijklmnopqrstuvwx" + ] } ``` @@ -47,7 +59,7 @@ form of `abcdefg12345...`. Response: -``` +```json {} ``` @@ -67,14 +79,18 @@ Where `room_id` is in the form of `!roomid12345:example.org`. Response: -``` +```json { - "num_quarantined": 10 # The number of media items successfully quarantined + "num_quarantined": 10 } ``` +The following fields are returned in the JSON response body: + +* `num_quarantined`: integer - The number of media items successfully quarantined + Note that there is a legacy endpoint, `POST -/_synapse/admin/v1/quarantine_media/<room_id >`, that operates the same. +/_synapse/admin/v1/quarantine_media/<room_id>`, that operates the same. However, it is deprecated and may be removed in a future release. ## Quarantining all media of a user @@ -91,12 +107,132 @@ POST /_synapse/admin/v1/user/<user_id>/media/quarantine {} ``` -Where `user_id` is in the form of `@bob:example.org`. +URL Parameters + +* `user_id`: string - User ID in the form of `@bob:example.org` + +Response: + +```json +{ + "num_quarantined": 10 +} +``` + +The following fields are returned in the JSON response body: + +* `num_quarantined`: integer - The number of media items successfully quarantined + +# Delete local media +This API deletes the *local* media from the disk of your own server. +This includes any local thumbnails and copies of media downloaded from +remote homeservers. +This API will not affect media that has been uploaded to external +media repositories (e.g https://github.com/turt2live/matrix-media-repo/). +See also [Purge Remote Media API](#purge-remote-media-api). + +## Delete a specific local media +Delete a specific `media_id`. 
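As a concrete sketch of the request described below (the host, server name, media ID, and access token are placeholder values):

```sh
# Placeholder values - substitute your own homeserver and admin access token.
HOST="https://matrix.example.com"
TOKEN="<admin access token>"

# Delete one locally stored media item, identified by its server name and media ID.
curl --request DELETE --header "Authorization: Bearer $TOKEN" --data '{}' \
    "$HOST/_synapse/admin/v1/media/matrix.example.com/abcdefghijklmnopqrstuvwx"
```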
+ +Request: + +``` +DELETE /_synapse/admin/v1/media/<server_name>/<media_id> + +{} +``` + +URL Parameters + +* `server_name`: string - The name of your local server (e.g `matrix.org`) +* `media_id`: string - The ID of the media (e.g `abcdefghijklmnopqrstuvwx`) + +Response: + +```json +{ + "deleted_media": [ + "abcdefghijklmnopqrstuvwx" + ], + "total": 1 +} +``` + +The following fields are returned in the JSON response body: + +* `deleted_media`: an array of strings - List of deleted `media_id` +* `total`: integer - Total number of deleted `media_id` + +## Delete local media by date or size + +Request: + +``` +POST /_synapse/admin/v1/media/<server_name>/delete?before_ts=<before_ts> + +{} +``` + +URL Parameters + +* `server_name`: string - The name of your local server (e.g `matrix.org`). +* `before_ts`: string representing a positive integer - Unix timestamp in ms. +Files that were last used before this timestamp will be deleted. It is the timestamp of +last access and not the timestamp creation. +* `size_gt`: Optional - string representing a positive integer - Size of the media in bytes. +Files that are larger will be deleted. Defaults to `0`. +* `keep_profiles`: Optional - string representing a boolean - Switch to also delete files +that are still used in image data (e.g user profile, room avatar). +If `false` these files will be deleted. Defaults to `true`. Response: +```json +{ + "deleted_media": [ + "abcdefghijklmnopqrstuvwx", + "abcdefghijklmnopqrstuvwz" + ], + "total": 2 +} +``` + +The following fields are returned in the JSON response body: + +* `deleted_media`: an array of strings - List of deleted `media_id` +* `total`: integer - Total number of deleted `media_id` + +# Purge Remote Media API + +The purge remote media API allows server admins to purge old cached remote media. + +The API is: + ``` +POST /_synapse/admin/v1/purge_media_cache?before_ts=<unix_timestamp_in_ms> + +{} +``` + +URL Parameters + +* `unix_timestamp_in_ms`: string representing a positive integer - Unix timestamp in ms. +All cached media that was last accessed before this timestamp will be removed. + +Response: + +```json { - "num_quarantined": 10 # The number of media items successfully quarantined + "deleted": 10 } ``` + +The following fields are returned in the JSON response body: + +* `deleted`: integer - The number of media items successfully deleted + +To use it, you will need to authenticate by providing an `access_token` for a +server admin: see [README.rst](README.rst). + +If the user re-requests purged remote media, synapse will re-request the media +from the originating server. diff --git a/docs/admin_api/purge_remote_media.rst b/docs/admin_api/purge_remote_media.rst deleted file mode 100644
index 00cb6b0589..0000000000 --- a/docs/admin_api/purge_remote_media.rst +++ /dev/null
@@ -1,20 +0,0 @@ -Purge Remote Media API -====================== - -The purge remote media API allows server admins to purge old cached remote -media. - -The API is:: - - POST /_synapse/admin/v1/purge_media_cache?before_ts=<unix_timestamp_in_ms> - - {} - -\... which will remove all cached media that was last accessed before -``<unix_timestamp_in_ms>``. - -To use it, you will need to authenticate by providing an ``access_token`` for a -server admin: see `README.rst <README.rst>`_. - -If the user re-requests purged remote media, synapse will re-request the media -from the originating server. diff --git a/docs/admin_api/purge_room.md b/docs/admin_api/purge_room.md
index ae01a543c6..54fea2db6d 100644 --- a/docs/admin_api/purge_room.md +++ b/docs/admin_api/purge_room.md
@@ -1,12 +1,13 @@ -Purge room API -============== +Deprecated: Purge room API +========================== + +**The old Purge room API is deprecated and will be removed in a future release. +See the new [Delete Room API](rooms.md#delete-room-api) for more details.** This API will remove all trace of a room from your database. All local users must have left the room before it can be removed. -See also: [Delete Room API](rooms.md#delete-room-api) - The API is: ``` diff --git a/docs/admin_api/register_api.rst b/docs/admin_api/register_api.rst
index 3a63109aa0..c3057b204b 100644 --- a/docs/admin_api/register_api.rst +++ b/docs/admin_api/register_api.rst
@@ -18,7 +18,8 @@ To fetch the nonce, you need to request one from the API:: Once you have the nonce, you can make a ``POST`` to the same URL with a JSON body containing the nonce, username, password, whether they are an admin -(optional, False by default), and a HMAC digest of the content. +(optional, False by default), and a HMAC digest of the content. Also you can +set the displayname (optional, ``username`` by default). As an example:: @@ -26,6 +27,7 @@ As an example:: > { "nonce": "thisisanonce", "username": "pepper_roni", + "displayname": "Pepper Roni", "password": "pizza", "admin": true, "mac": "mac_digest_here" diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md
index fa9b914fa7..9e560003a9 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md
@@ -1,3 +1,15 @@ +# Contents +- [List Room API](#list-room-api) + * [Parameters](#parameters) + * [Usage](#usage) +- [Room Details API](#room-details-api) +- [Room Members API](#room-members-api) +- [Delete Room API](#delete-room-api) + * [Parameters](#parameters-1) + * [Response](#response) + * [Undoing room shutdowns](#undoing-room-shutdowns) +- [Make Room Admin API](#make-room-admin-api) + # List Room API The List Room admin API allows server admins to get a list of rooms on their @@ -76,7 +88,7 @@ GET /_synapse/admin/v1/rooms Response: -``` +```jsonc { "rooms": [ { @@ -128,7 +140,7 @@ GET /_synapse/admin/v1/rooms?search_term=TWIM Response: -``` +```json { "rooms": [ { @@ -163,7 +175,7 @@ GET /_synapse/admin/v1/rooms?order_by=size Response: -``` +```jsonc { "rooms": [ { @@ -219,14 +231,14 @@ GET /_synapse/admin/v1/rooms?order_by=size&from=100 Response: -``` +```jsonc { "rooms": [ { "room_id": "!mscvqgqpHYjBGDxNym:matrix.org", "name": "Music Theory", "canonical_alias": "#musictheory:matrix.org", - "joined_members": 127 + "joined_members": 127, "joined_local_members": 2, "version": "1", "creator": "@foo:matrix.org", @@ -243,7 +255,7 @@ Response: "room_id": "!twcBhHVdZlQWuuxBhN:termina.org.uk", "name": "weechat-matrix", "canonical_alias": "#weechat-matrix:termina.org.uk", - "joined_members": 137 + "joined_members": 137, "joined_local_members": 20, "version": "4", "creator": "@foo:termina.org.uk", @@ -265,12 +277,10 @@ Response: Once the `next_token` parameter is no longer present, we know we've reached the end of the list. -# DRAFT: Room Details API +# Room Details API The Room Details admin API allows server admins to get all details of a room. -This API is still a draft and details might change! - The following fields are possible in the JSON response body: * `room_id` - The ID of the room. @@ -280,6 +290,7 @@ The following fields are possible in the JSON response body: * `canonical_alias` - The canonical (main) alias address of the room. * `joined_members` - How many users are currently in the room. * `joined_local_members` - How many local users are currently in the room. +* `joined_local_devices` - How many local devices are currently in the room. * `version` - The version of the room as a string. * `creator` - The `user_id` of the room creator. * `encryption` - Algorithm of end-to-end encryption of messages. Is `null` if encryption is not active. @@ -302,15 +313,16 @@ GET /_synapse/admin/v1/rooms/<room_id> Response: -``` +```json { "room_id": "!mscvqgqpHYjBGDxNym:matrix.org", "name": "Music Theory", "avatar": "mxc://matrix.org/AQDaVFlbkQoErdOgqWRgiGSV", "topic": "Theory, Composition, Notation, Analysis", "canonical_alias": "#musictheory:matrix.org", - "joined_members": 127 + "joined_members": 127, "joined_local_members": 2, + "joined_local_devices": 2, "version": "1", "creator": "@foo:matrix.org", "encryption": null, @@ -344,13 +356,13 @@ GET /_synapse/admin/v1/rooms/<room_id>/members Response: -``` +```json { "members": [ "@foo:matrix.org", "@bar:matrix.org", - "@foobar:matrix.org - ], + "@foobar:matrix.org" + ], "total": 3 } ``` @@ -359,8 +371,6 @@ Response: The Delete Room admin API allows server admins to remove rooms from server and block these rooms. -It is a combination and improvement of "[Shutdown room](shutdown_room.md)" -and "[Purge room](purge_room.md)" API. Shuts down a room. Moves all local users and room aliases automatically to a new room if `new_room_user_id` is set. Otherwise local users only @@ -384,7 +394,7 @@ the new room. 
Users on other servers will be unaffected. The API is: -```json +``` POST /_synapse/admin/v1/rooms/<room_id>/delete ``` @@ -441,6 +451,10 @@ The following JSON body parameters are available: future attempts to join the room. Defaults to `false`. * `purge` - Optional. If set to `true`, it will remove all traces of the room from your database. Defaults to `true`. +* `force_purge` - Optional, and ignored unless `purge` is `true`. If set to `true`, it + will force a purge to go ahead even if there are local users still in the room. Do not + use this unless a regular `purge` operation fails, as it could leave those users' + clients in a confused state. The JSON body must not be empty. The body must be at least `{}`. @@ -453,3 +467,47 @@ The following fields are returned in the JSON response body: * `local_aliases` - An array of strings representing the local aliases that were migrated from the old room to the new. * `new_room_id` - A string representing the room ID of the new room. + + +## Undoing room shutdowns + +*Note*: This guide may be outdated by the time you read it. By nature of room shutdowns being performed at the database level, +the structure can and does change without notice. + +First, it's important to understand that a room shutdown is very destructive. Undoing a shutdown is not as simple as pretending it +never happened - work has to be done to move forward instead of resetting the past. In fact, in some cases it might not be possible +to recover at all: + +* If the room was invite-only, your users will need to be re-invited. +* If the room no longer has any members at all, it'll be impossible to rejoin. +* The first user to rejoin will have to do so via an alias on a different server. + +With all that being said, if you still want to try and recover the room: + +1. For safety reasons, shut down Synapse. +2. In the database, run `DELETE FROM blocked_rooms WHERE room_id = '!example:example.org';` + * For caution: it's recommended to run this in a transaction: `BEGIN; DELETE ...;`, verify you got 1 result, then `COMMIT;`. + * The room ID is the same one supplied to the shutdown room API, not the Content Violation room. +3. Restart Synapse. + +You will have to manually handle, if you so choose, the following: + +* Aliases that would have been redirected to the Content Violation room. +* Users that would have been booted from the room (and will have been force-joined to the Content Violation room). +* Removal of the Content Violation room if desired. + + +# Make Room Admin API + +Grants another user the highest power available to a local user who is in the room. +If the user is not in the room, and it is not publicly joinable, then invite the user. + +By default the server admin (the caller) is granted power, but another user can +optionally be specified, e.g.: + +``` + POST /_synapse/admin/v1/rooms/<room_id_or_alias>/make_room_admin + { + "user_id": "@foo:example.com" + } +``` diff --git a/docs/admin_api/shutdown_room.md b/docs/admin_api/shutdown_room.md
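For operators scripting the Delete Room API described above, the following is a rough, non-authoritative sketch using Python's `requests` library; the homeserver URL, admin access token and room ID are placeholders, not values taken from this changeset:

```python
import requests

HOMESERVER = "https://synapse.example.com"  # placeholder base URL
ADMIN_TOKEN = "<admin access token>"        # placeholder admin token
ROOM_ID = "!badroom:example.com"            # placeholder room to shut down

# Block the room and purge it from the database. new_room_user_id is omitted,
# so local users are simply removed rather than moved to a replacement room.
resp = requests.post(
    f"{HOMESERVER}/_synapse/admin/v1/rooms/{ROOM_ID}/delete",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
    json={"block": True, "purge": True},
)
resp.raise_for_status()
# The response lists kicked_users, failed_to_kick_users, local_aliases and new_room_id.
print(resp.json())
```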
index 9b1cb1c184..856a629487 100644 --- a/docs/admin_api/shutdown_room.md +++ b/docs/admin_api/shutdown_room.md
@@ -1,4 +1,7 @@ -# Shutdown room API +# Deprecated: Shutdown room API + +**The old Shutdown room API is deprecated and will be removed in a future release. +See the new [Delete Room API](rooms.md#delete-room-api) for more details.** Shuts down a room, preventing new joins and moves local users and room aliases automatically to a new room. The new room will be created with the user specified by the @@ -10,8 +13,6 @@ disallow any further invites or joins. The local server will only have the power to move local user and room aliases to the new room. Users on other servers will be unaffected. -See also: [Delete Room API](rooms.md#delete-room-api) - ## API You will need to authenticate with an access token for an admin user. diff --git a/docs/admin_api/statistics.md b/docs/admin_api/statistics.md new file mode 100644
index 0000000000..d398a120fb --- /dev/null +++ b/docs/admin_api/statistics.md
@@ -0,0 +1,83 @@ +# Users' media usage statistics + +Returns information about all local media usage of users. The results can be +filtered by time and user. + +The API is: + +``` +GET /_synapse/admin/v1/statistics/users/media +``` + +To use it, you will need to authenticate by providing an `access_token` +for a server admin: see [README.rst](README.rst). + +A response body like the following is returned: + +```json +{ + "users": [ + { + "displayname": "foo_user_0", + "media_count": 2, + "media_length": 134, + "user_id": "@foo_user_0:test" + }, + { + "displayname": "foo_user_1", + "media_count": 2, + "media_length": 134, + "user_id": "@foo_user_1:test" + } + ], + "next_token": 3, + "total": 10 +} +``` + +To paginate, check for `next_token` and if present, call the endpoint +again with `from` set to the value of `next_token`. This will return a new page. + +If the endpoint does not return a `next_token` then there are no more +users to paginate through. + +**Parameters** + +The following parameters should be set in the URL: + +* `limit`: string representing a positive integer - Is optional but is + used for pagination, denoting the maximum number of items to return + in this call. Defaults to `100`. +* `from`: string representing a positive integer - Is optional but used for pagination, + denoting the offset in the returned results. This should be treated as an opaque value + and not explicitly set to anything other than the return value of `next_token` from a + previous call. Defaults to `0`. +* `order_by` - string - The method in which to sort the returned list of users. Valid values are: + - `user_id` - Users are ordered alphabetically by `user_id`. This is the default. + - `displayname` - Users are ordered alphabetically by `displayname`. + - `media_length` - Users are ordered by the total size of uploaded media in bytes. + Smallest to largest. + - `media_count` - Users are ordered by number of uploaded media. Smallest to largest. +* `from_ts` - string representing a positive integer - Considers only + files created at this timestamp or later. Unix timestamp in ms. +* `until_ts` - string representing a positive integer - Considers only + files created at this timestamp or earlier. Unix timestamp in ms. +* `search_term` - string - Filter users by their user ID localpart **or** displayname. + The search term can be found in any part of the string. + Defaults to no filtering. +* `dir` - string - Direction of order. Either `f` for forwards or `b` for backwards. + Setting this value to `b` will reverse the above sort order. Defaults to `f`. + + +**Response** + +The following fields are returned in the JSON response body: + +* `users` - An array of objects, each containing information + about the user and their local media. Objects contain the following fields: + - `displayname` - string - Displayname of this user. + - `media_count` - integer - Number of uploaded media by this user. + - `media_length` - integer - Size of uploaded media in bytes by this user. + - `user_id` - string - Fully-qualified user ID (ex. `@user:server.com`). +* `next_token` - integer - Opaque value used for pagination. See above. +* `total` - integer - Total number of users after filtering. diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst
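As a non-authoritative sketch of how the `from`/`next_token` pagination described above can be driven from a script (the homeserver URL and admin token are placeholders):

```python
import requests

HOMESERVER = "https://synapse.example.com"  # placeholder base URL
ADMIN_TOKEN = "<admin access token>"        # placeholder admin token

def iter_media_statistics():
    """Yield per-user media statistics, following next_token pagination."""
    params = {"limit": "100", "order_by": "media_length", "dir": "b"}
    while True:
        resp = requests.get(
            f"{HOMESERVER}/_synapse/admin/v1/statistics/users/media",
            headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
            params=params,
        )
        resp.raise_for_status()
        body = resp.json()
        yield from body["users"]
        if "next_token" not in body:
            break  # no more pages
        params["from"] = str(body["next_token"])

for user in iter_media_statistics():
    print(user["user_id"], user["media_count"], user["media_length"])
```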
index 7ca902faba..e4d6f8203b 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst
@@ -30,7 +30,12 @@ It returns a JSON body like the following: ], "avatar_url": "<avatar_url>", "admin": false, - "deactivated": false + "deactivated": false, + "password_hash": "$2b$12$p9B4GkqYdRTPGD", + "creation_ts": 1560432506, + "appservice_id": null, + "consent_server_notice_sent": null, + "consent_version": null } URL parameters: @@ -139,7 +144,6 @@ A JSON body is returned with the following shape: "users": [ { "name": "<user_id1>", - "password_hash": "<password_hash1>", "is_guest": 0, "admin": 0, "user_type": null, @@ -148,7 +152,6 @@ A JSON body is returned with the following shape: "avatar_url": null }, { "name": "<user_id2>", - "password_hash": "<password_hash2>", "is_guest": 0, "admin": 1, "user_type": null, @@ -176,6 +179,13 @@ The api is:: GET /_synapse/admin/v1/whois/<user_id> +and:: + + GET /_matrix/client/r0/admin/whois/<userId> + +See also: `Client Server API Whois +<https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-admin-whois-userid>`_ + To use it, you will need to authenticate by providing an ``access_token`` for a server admin: see `README.rst <README.rst>`_. @@ -254,7 +264,7 @@ with a body of: { "new_password": "<secret>", - "logout_devices": true, + "logout_devices": true } To use it, you will need to authenticate by providing an ``access_token`` for a @@ -341,6 +351,124 @@ The following fields are returned in the JSON response body: - ``total`` - Number of rooms. + +List media of a user +================================ +Gets a list of all local media that a specific ``user_id`` has created. +The response is ordered by creation date descending and media ID descending. +The newest media is on top. + +The API is:: + + GET /_synapse/admin/v1/users/<user_id>/media + +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see `README.rst <README.rst>`_. + +A response body like the following is returned: + +.. code:: json + + { + "media": [ + { + "created_ts": 100400, + "last_access_ts": null, + "media_id": "qXhyRzulkwLsNHTbpHreuEgo", + "media_length": 67, + "media_type": "image/png", + "quarantined_by": null, + "safe_from_quarantine": false, + "upload_name": "test1.png" + }, + { + "created_ts": 200400, + "last_access_ts": null, + "media_id": "FHfiSnzoINDatrXHQIXBtahw", + "media_length": 67, + "media_type": "image/png", + "quarantined_by": null, + "safe_from_quarantine": false, + "upload_name": "test2.png" + } + ], + "next_token": 3, + "total": 2 + } + +To paginate, check for ``next_token`` and if present, call the endpoint again +with ``from`` set to the value of ``next_token``. This will return a new page. + +If the endpoint does not return a ``next_token`` then there are no more +media to paginate through. + +**Parameters** + +The following parameters should be set in the URL: + +- ``user_id`` - string - fully qualified: for example, ``@user:server.com``. +- ``limit``: string representing a positive integer - Is optional but is used for pagination, + denoting the maximum number of items to return in this call. Defaults to ``100``. +- ``from``: string representing a positive integer - Is optional but used for pagination, + denoting the offset in the returned results. This should be treated as an opaque value and + not explicitly set to anything other than the return value of ``next_token`` from a previous call. + Defaults to ``0``. + +**Response** + +The following fields are returned in the JSON response body: + +- ``media`` - An array of objects, each containing information about a media item.
+ Media objects contain the following fields: + + - ``created_ts`` - integer - Timestamp when the content was uploaded in ms. + - ``last_access_ts`` - integer - Timestamp when the content was last accessed in ms. + - ``media_id`` - string - The ID used to refer to the media. + - ``media_length`` - integer - Length of the media in bytes. + - ``media_type`` - string - The MIME-type of the media. + - ``quarantined_by`` - string - The user ID that initiated the quarantine request + for this media. + + - ``safe_from_quarantine`` - bool - Whether this media is marked as safe from quarantining. + - ``upload_name`` - string - The name the media was uploaded with. + +- ``next_token``: integer - Opaque value used for pagination. See above. +- ``total`` - integer - Total number of media. + +Login as a user +=============== + +Get an access token that can be used to authenticate as that user. Useful +when admins wish to take actions on behalf of a user. + +The API is:: + + POST /_synapse/admin/v1/users/<user_id>/login + {} + +An optional ``valid_until_ms`` field can be specified in the request body as an +integer timestamp that specifies when the token should expire. By default tokens +do not expire. + +A response body like the following is returned: + +.. code:: json + + { + "access_token": "<opaque_access_token_string>" + } + + +This API does *not* generate a new device for the user, and so will not appear +in their ``/devices`` list, and in general the target user should not be able to +tell that an admin has logged in as them. + +To expire the token, call the standard ``/logout`` API with the token. + +Note: The token will expire if the *admin* user calls ``/logout/all`` from any +of their devices, but the token will *not* expire if the target user does the +same. + + User devices ============ @@ -375,7 +503,8 @@ A response body like the following is returned: "last_seen_ts": 1474491775025, "user_id": "<user_id>" } - ] + ], + "total": 2 } **Parameters** @@ -400,6 +529,8 @@ The following fields are returned in the JSON response body: devices was last seen. (May be a few minutes out of date, for efficiency reasons). - ``user_id`` - Owner of device. +- ``total`` - Total number of user's devices. + Delete multiple devices ------------------ Deletes the given devices for a specific ``user_id``, and invalidates @@ -525,3 +656,82 @@ The following parameters should be set in the URL: - ``user_id`` - fully qualified: for example, ``@user:server.com``. - ``device_id`` - The device to delete. + +List all pushers +================ +Gets information about all pushers for a specific ``user_id``. + +The API is:: + + GET /_synapse/admin/v1/users/<user_id>/pushers + +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see `README.rst <README.rst>`_. + +A response body like the following is returned: + +.. code:: json + + { + "pushers": [ + { + "app_display_name":"HTTP Push Notifications", + "app_id":"m.http", + "data": { + "url":"example.com" + }, + "device_display_name":"pushy push", + "kind":"http", + "lang":"None", + "profile_tag":"", + "pushkey":"a@example.com" + } + ], + "total": 1 + } + +**Parameters** + +The following parameters should be set in the URL: + +- ``user_id`` - fully qualified: for example, ``@user:server.com``.
+ +**Response** + +The following fields are returned in the JSON response body: + +- ``pushers`` - An array containing the current pushers for the user + + - ``app_display_name`` - string - A string that will allow the user to identify + what application owns this pusher. + + - ``app_id`` - string - This is a reverse-DNS style identifier for the application. + Max length, 64 chars. + + - ``data`` - A dictionary of information for the pusher implementation itself. + + - ``url`` - string - Required if ``kind`` is ``http``. The URL to use to send + notifications to. + + - ``format`` - string - The format to use when sending notifications to the + Push Gateway. + + - ``device_display_name`` - string - A string that will allow the user to identify + what device owns this pusher. + + - ``kind`` - string - The kind of pusher. "http" is a pusher that sends HTTP pokes. + - ``lang`` - string - The preferred language for receiving notifications + (e.g. 'en' or 'en-US'). + + - ``profile_tag`` - string - This string determines which set of device specific rules + this pusher executes. + + - ``pushkey`` - string - This is a unique identifier for this pusher. + Max length, 512 bytes. + +- ``total`` - integer - Number of pushers. + +See also `Client-Server API Spec <https://matrix.org/docs/spec/client_server/latest#get-matrix-client-r0-pushers>`_ diff --git a/docs/code_style.md b/docs/code_style.md
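As an illustrative, non-authoritative sketch of the "Login as a user" API described above, the following uses Python's ``requests`` library to obtain a time-limited token for a user and then revoke it; the homeserver URL, admin token and user ID are placeholders.

.. code:: python

    import time

    import requests

    HOMESERVER = "https://synapse.example.com"  # placeholder base URL
    ADMIN_TOKEN = "<admin access token>"        # placeholder admin token
    USER_ID = "@alice:example.com"              # placeholder target user

    # Request a token that expires in one hour (valid_until_ms is optional).
    resp = requests.post(
        f"{HOMESERVER}/_synapse/admin/v1/users/{USER_ID}/login",
        headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
        json={"valid_until_ms": int(time.time() * 1000) + 60 * 60 * 1000},
    )
    resp.raise_for_status()
    user_token = resp.json()["access_token"]

    # The token behaves like a normal user token; expire it with /logout.
    requests.post(
        f"{HOMESERVER}/_matrix/client/r0/logout",
        headers={"Authorization": f"Bearer {user_token}"},
    )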
index 6ef6f80290..f6c825d7d4 100644 --- a/docs/code_style.md +++ b/docs/code_style.md
@@ -64,8 +64,6 @@ save as it takes a while and is very resource intensive. - Use underscores for functions and variables. - **Docstrings**: should follow the [google code style](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings). - This is so that we can generate documentation with - [sphinx](http://sphinxcontrib-napoleon.readthedocs.org/en/latest/). See the [examples](http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) in the sphinx documentation. diff --git a/docs/dev/cas.md b/docs/dev/cas.md
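As a generic illustration of the Google docstring style referenced above (an invented example, not code from the Synapse tree):

```python
def localpart_of(user_id: str) -> str:
    """Return the localpart of a Matrix user ID.

    Args:
        user_id: A fully qualified user ID, e.g. ``@alice:example.com``.

    Returns:
        The localpart, e.g. ``alice``.

    Raises:
        ValueError: If ``user_id`` does not look like a user ID.
    """
    if not user_id.startswith("@") or ":" not in user_id:
        raise ValueError(f"not a valid user ID: {user_id!r}")
    return user_id[1:].split(":", 1)[0]
```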
index f8d02cc82c..592b2d8d4f 100644 --- a/docs/dev/cas.md +++ b/docs/dev/cas.md
@@ -31,7 +31,7 @@ easy to run CAS implementation built on top of Django. You should now have a Django project configured to serve CAS authentication with a single user created. -## Configure Synapse (and Riot) to use CAS +## Configure Synapse (and Element) to use CAS 1. Modify your `homeserver.yaml` to enable CAS and point it to your locally running Django test server: @@ -51,9 +51,9 @@ and that the CAS server is on port 8000, both on localhost. ## Testing the configuration -Then in Riot: +Then in Element: -1. Visit the login page with a Riot pointing at your homeserver. +1. Visit the login page with an Element client pointing at your homeserver. 2. Click the Single Sign-On button. 3. Login using the credentials created with `createsuperuser`. 4. You should be logged in. diff --git a/docs/manhole.md b/docs/manhole.md
index 7375f5ad46..37d1d7823c 100644 --- a/docs/manhole.md +++ b/docs/manhole.md
@@ -5,22 +5,54 @@ The "manhole" allows server administrators to access a Python shell on a running Synapse installation. This is a very powerful mechanism for administration and debugging. +**_Security Warning_** + +Note that this will give administrative access to synapse to **all users** with +shell access to the server. It should therefore **not** be enabled in +environments where untrusted users have shell access. + +*** + To enable it, first uncomment the `manhole` listener configuration in -`homeserver.yaml`: +`homeserver.yaml`. The configuration is slightly different if you're using docker. + +#### Docker config + +If you are using Docker, set `bind_addresses` to `['0.0.0.0']` as shown: ```yaml listeners: - port: 9000 - bind_addresses: ['::1', '127.0.0.1'] + bind_addresses: ['0.0.0.0'] type: manhole ``` -(`bind_addresses` in the above is important: it ensures that access to the -manhole is only possible for local users). +When using `docker run` to start the server, you will then need to change the command to the following to include the +`manhole` port forwarding. The `-p 127.0.0.1:9000:9000` below is important: it +ensures that access to the `manhole` is only possible for local users. -Note that this will give administrative access to synapse to **all users** with -shell access to the server. It should therefore **not** be enabled in -environments where untrusted users have shell access. +```bash +docker run -d --name synapse \ + --mount type=volume,src=synapse-data,dst=/data \ + -p 8008:8008 \ + -p 127.0.0.1:9000:9000 \ + matrixdotorg/synapse:latest +``` + +#### Native config + +If you are not using docker, set `bind_addresses` to `['::1', '127.0.0.1']` as shown. +The `bind_addresses` in the example below is important: it ensures that access to the +`manhole` is only possible for local users). + +```yaml +listeners: + - port: 9000 + bind_addresses: ['::1', '127.0.0.1'] + type: manhole +``` + +#### Accessing synapse manhole Then restart synapse, and point an ssh client at port 9000 on localhost, using the username `matrix`: @@ -35,9 +67,12 @@ This gives a Python REPL in which `hs` gives access to the `synapse.server.HomeServer` object - which in turn gives access to many other parts of the process. +Note that any call which returns a coroutine will need to be wrapped in `ensureDeferred`. + As a simple example, retrieving an event from the database: -``` ->>> hs.get_datastore().get_event('$1416420717069yeQaw:matrix.org') +```pycon +>>> from twisted.internet import defer +>>> defer.ensureDeferred(hs.get_datastore().get_event('$1416420717069yeQaw:matrix.org')) <Deferred at 0x7ff253fc6998 current result: <FrozenEvent event_id='$1416420717069yeQaw:matrix.org', type='m.room.create', state_key=''>> ``` diff --git a/docs/message_retention_policies.md b/docs/message_retention_policies.md
index 1dd60bdad9..75d2028e17 100644 --- a/docs/message_retention_policies.md +++ b/docs/message_retention_policies.md
@@ -136,24 +136,34 @@ the server's database. ### Lifetime limits -**Note: this feature is mainly useful within a closed federation or on -servers that don't federate, because there currently is no way to -enforce these limits in an open federation.** - -Server admins can restrict the values their local users are allowed to -use for both `min_lifetime` and `max_lifetime`. These limits can be -defined as such in the `retention` section of the configuration file: +Server admins can set limits on the values of `max_lifetime` to use when +purging old events in a room. These limits can be defined as such in the +`retention` section of the configuration file: ```yaml allowed_lifetime_min: 1d allowed_lifetime_max: 1y ``` -Here, `allowed_lifetime_min` is the lowest value a local user can set -for both `min_lifetime` and `max_lifetime`, and `allowed_lifetime_max` -is the highest value. Both parameters are optional (e.g. setting -`allowed_lifetime_min` but not `allowed_lifetime_max` only enforces a -minimum and no maximum). +The limits are considered when running purge jobs. If necessary, the +effective value of `max_lifetime` will be brought between +`allowed_lifetime_min` and `allowed_lifetime_max` (inclusive). +This means that, if the value of `max_lifetime` defined in the room's state +is lower than `allowed_lifetime_min`, the value of `allowed_lifetime_min` +will be used instead. Likewise, if the value of `max_lifetime` is higher +than `allowed_lifetime_max`, the value of `allowed_lifetime_max` will be +used instead. + +In the example above, we ensure Synapse never deletes events that are less +than one day old, and that it always deletes events that are over a year +old. + +If a default policy is set, and its `max_lifetime` value is lower than +`allowed_lifetime_min` or higher than `allowed_lifetime_max`, the same +process applies. + +Both parameters are optional; if one is omitted Synapse won't use it to +adjust the effective value of `max_lifetime`. Like other settings in this section, these parameters can be expressed either as a duration or as a number of milliseconds. diff --git a/docs/metrics-howto.md b/docs/metrics-howto.md
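The clamping rule described above can be summarised with a small illustrative sketch (plain Python, not Synapse's actual implementation; the default limits shown assume the 1d/1y example configuration):

```python
def effective_max_lifetime(
    room_max_lifetime_ms: int,
    allowed_lifetime_min_ms: int = 24 * 60 * 60 * 1000,        # 1d
    allowed_lifetime_max_ms: int = 365 * 24 * 60 * 60 * 1000,  # 1y
) -> int:
    """Clamp a room's max_lifetime between the configured limits."""
    return min(max(room_max_lifetime_ms, allowed_lifetime_min_ms), allowed_lifetime_max_ms)

# A room asking to keep events for only one hour is purged after one day instead.
assert effective_max_lifetime(60 * 60 * 1000) == 24 * 60 * 60 * 1000
```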
index b386ec91c1..6b84153274 100644 --- a/docs/metrics-howto.md +++ b/docs/metrics-howto.md
@@ -13,10 +13,12 @@ can be enabled by adding the \"metrics\" resource to the existing listener as such: - resources: - - names: - - client - - metrics + ```yaml + resources: + - names: + - client + - metrics + ``` This provides a simple way of adding metrics to your Synapse installation, and serves under `/_synapse/metrics`. If you do not @@ -31,11 +33,13 @@ Add a new listener to homeserver.yaml: - listeners: - - type: metrics - port: 9000 - bind_addresses: - - '0.0.0.0' + ```yaml + listeners: + - type: metrics + port: 9000 + bind_addresses: + - '0.0.0.0' + ``` For both options, you will need to ensure that `enable_metrics` is set to `True`. @@ -47,10 +51,13 @@ It needs to set the `metrics_path` to a non-default value (under `scrape_configs`): - - job_name: "synapse" - metrics_path: "/_synapse/metrics" - static_configs: - - targets: ["my.server.here:port"] + ```yaml + - job_name: "synapse" + scrape_interval: 15s + metrics_path: "/_synapse/metrics" + static_configs: + - targets: ["my.server.here:port"] + ``` where `my.server.here` is the IP address of Synapse, and `port` is the listener port configured with the `metrics` resource. @@ -60,6 +67,9 @@ 1. Restart Prometheus. +1. Consider using the [grafana dashboard](https://github.com/matrix-org/synapse/tree/master/contrib/grafana/) + and required [recording rules](https://github.com/matrix-org/synapse/tree/master/contrib/prometheus/) + ## Monitoring workers To monitor a Synapse installation using @@ -74,9 +84,9 @@ To allow collecting metrics from a worker, you need to add a under `worker_listeners`: ```yaml - - type: metrics - bind_address: '' - port: 9101 + - type: metrics + bind_address: '' + port: 9101 ``` The `bind_address` and `port` parameters should be set so that @@ -85,6 +95,38 @@ don't clash with an existing worker. With this example, the worker's metrics would then be available on `http://127.0.0.1:9101`. +Example Prometheus target for Synapse with workers: + +```yaml + - job_name: "synapse" + scrape_interval: 15s + metrics_path: "/_synapse/metrics" + static_configs: + - targets: ["my.server.here:port"] + labels: + instance: "my.server" + job: "master" + index: 1 + - targets: ["my.workerserver.here:port"] + labels: + instance: "my.server" + job: "generic_worker" + index: 1 + - targets: ["my.workerserver.here:port"] + labels: + instance: "my.server" + job: "generic_worker" + index: 2 + - targets: ["my.workerserver.here:port"] + labels: + instance: "my.server" + job: "media_repository" + index: 1 +``` + +Labels (`instance`, `job`, `index`) can be defined as anything. +The labels are used to group graphs in grafana. + ## Renaming of metrics & deprecation of old names in 1.2 Synapse 1.2 updates the Prometheus metrics to match the naming diff --git a/docs/openid.md b/docs/openid.md
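Before pointing Prometheus at the listener, it can be worth sanity-checking that the metrics endpoint responds; a quick, non-authoritative check (the address is a placeholder, and `curl` works just as well):

```python
import requests

# Replace with the address and port of your metrics listener or resource.
resp = requests.get("http://my.server.here:9000/_synapse/metrics")
resp.raise_for_status()
# The body is Prometheus exposition format: one "name{labels} value" sample per line.
print("\n".join(resp.text.splitlines()[:5]))
```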
index 70b37f858b..da391f74aa 100644 --- a/docs/openid.md +++ b/docs/openid.md
@@ -37,7 +37,7 @@ as follows: provided by `matrix.org` so no further action is needed. * If you installed Synapse into a virtualenv, run `/path/to/env/bin/pip - install synapse[oidc]` to install the necessary dependencies. + install matrix-synapse[oidc]` to install the necessary dependencies. * For other installation mechanisms, see the documentation provided by the maintainer. @@ -52,14 +52,39 @@ specific providers. Here are a few configs for providers that should work with Synapse. +### Microsoft Azure Active Directory +Azure AD can act as an OpenID Connect Provider. Register a new application under +*App registrations* in the Azure AD management console. The RedirectURI for your +application should point to your matrix server: `[synapse public baseurl]/_synapse/oidc/callback` + +Go to *Certificates & secrets* and register a new client secret. Make note of your +Directory (tenant) ID as it will be used in the Azure links. +Edit your Synapse config file and change the `oidc_config` section: + +```yaml +oidc_config: + enabled: true + issuer: "https://login.microsoftonline.com/<tenant id>/v2.0" + client_id: "<client id>" + client_secret: "<client secret>" + scopes: ["openid", "profile"] + authorization_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/authorize" + token_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/token" + userinfo_endpoint: "https://graph.microsoft.com/oidc/userinfo" + + user_mapping_provider: + config: + localpart_template: "{{ user.preferred_username.split('@')[0] }}" + display_name_template: "{{ user.name }}" +``` + ### [Dex][dex-idp] [Dex][dex-idp] is a simple, open-source, certified OpenID Connect Provider. Although it is designed to help building a full-blown provider with an external database, it can be configured with static passwords in a config file. -Follow the [Getting Started -guide](https://github.com/dexidp/dex/blob/master/Documentation/getting-started.md) +Follow the [Getting Started guide](https://dexidp.io/docs/getting-started/) to install Dex. Edit `examples/config-dev.yaml` config file from the Dex repo to add a client: @@ -73,7 +98,7 @@ staticClients: name: 'Synapse' ``` -Run with `dex serve examples/config-dex.yaml`. +Run with `dex serve examples/config-dev.yaml`. Synapse config: @@ -180,7 +205,7 @@ GitHub is a bit special as it is not an OpenID Connect compliant provider, but just a regular OAuth2 provider. The [`/user` API endpoint](https://developer.github.com/v3/users/#get-the-authenticated-user) -can be used to retrieve information on the authenticated user. As the Synaspse +can be used to retrieve information on the authenticated user. As the Synapse login mechanism needs an attribute to uniquely identify users, and that endpoint does not return a `sub` property, an alternative `subject_claim` has to be set. 
@@ -238,13 +263,36 @@ Synapse config: ```yaml oidc_config: - enabled: true - issuer: "https://id.twitch.tv/oauth2/" - client_id: "your-client-id" # TO BE FILLED - client_secret: "your-client-secret" # TO BE FILLED - client_auth_method: "client_secret_post" - user_mapping_provider: - config: - localpart_template: '{{ user.preferred_username }}' - display_name_template: '{{ user.name }}' + enabled: true + issuer: "https://id.twitch.tv/oauth2/" + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + client_auth_method: "client_secret_post" + user_mapping_provider: + config: + localpart_template: "{{ user.preferred_username }}" + display_name_template: "{{ user.name }}" +``` + +### GitLab + +1. Create a [new application](https://gitlab.com/profile/applications). +2. Add the `read_user` and `openid` scopes. +3. Add this Callback URL: `[synapse public baseurl]/_synapse/oidc/callback` + +Synapse config: + +```yaml +oidc_config: + enabled: true + issuer: "https://gitlab.com/" + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + client_auth_method: "client_secret_post" + scopes: ["openid", "read_user"] + user_profile_method: "userinfo_endpoint" + user_mapping_provider: + config: + localpart_template: '{{ user.nickname }}' + display_name_template: '{{ user.name }}' ``` diff --git a/docs/password_auth_providers.md b/docs/password_auth_providers.md
index 7d98d9f255..d2cdb9b2f4 100644 --- a/docs/password_auth_providers.md +++ b/docs/password_auth_providers.md
@@ -26,6 +26,7 @@ Password auth provider classes must provide the following methods: It should perform any appropriate sanity checks on the provided configuration, and return an object which is then passed into + `__init__`. This method should have the `@staticmethod` decorator. diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md
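To make the `parse_config`/`__init__` hand-off concrete, here is a bare-bones, non-authoritative sketch of a password auth provider module; the class name and config key are invented for illustration:

```python
class ExamplePasswordProvider:
    @staticmethod
    def parse_config(config):
        # Sanity-check this provider's section of homeserver.yaml and return
        # whatever object should be handed to __init__ below.
        if "shared_secret" not in config:
            raise Exception("example provider requires a shared_secret")
        return config

    def __init__(self, config, account_handler):
        self.config = config
        self.account_handler = account_handler

    async def check_password(self, user_id, password):
        # Accept the login only if the supplied password matches the secret.
        return password == self.config["shared_secret"]
```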
index 46d8f35771..c7020f2df3 100644 --- a/docs/reverse_proxy.md +++ b/docs/reverse_proxy.md
@@ -54,7 +54,7 @@ server { proxy_set_header X-Forwarded-For $remote_addr; # Nginx by default only allows file uploads up to 1M in size # Increase client_max_body_size to match max_upload_size defined in homeserver.yaml - client_max_body_size 10M; + client_max_body_size 50M; } } ``` diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 8a3206e845..dd981609ac 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml
@@ -119,7 +119,7 @@ pid_file: DATADIR/homeserver.pid # For example, for room version 1, default_room_version should be set # to "1". # -#default_room_version: "5" +#default_room_version: "6" # The GC threshold parameters to pass to `gc.set_threshold`, if defined # @@ -144,6 +144,47 @@ pid_file: DATADIR/homeserver.pid # #enable_search: false +# Prevent outgoing requests from being sent to the following blacklisted IP address +# CIDR ranges. If this option is not specified then it defaults to private IP +# address ranges (see the example below). +# +# The blacklist applies to the outbound requests for federation, identity servers, +# push servers, and for checking key validity for third-party invite events. +# +# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly +# listed here, since they correspond to unroutable addresses.) +# +# This option replaces federation_ip_range_blacklist in Synapse v1.25.0. +# +#ip_range_blacklist: +# - '127.0.0.0/8' +# - '10.0.0.0/8' +# - '172.16.0.0/12' +# - '192.168.0.0/16' +# - '100.64.0.0/10' +# - '192.0.0.0/24' +# - '169.254.0.0/16' +# - '198.18.0.0/15' +# - '192.0.2.0/24' +# - '198.51.100.0/24' +# - '203.0.113.0/24' +# - '224.0.0.0/4' +# - '::1/128' +# - 'fe80::/10' +# - 'fc00::/7' + +# List of IP address CIDR ranges that should be allowed for federation, +# identity servers, push servers, and for checking key validity for +# third-party invite events. This is useful for specifying exceptions to +# wide-ranging blacklisted target IP ranges - e.g. for communication with +# a push server only visible in your network. +# +# This whitelist overrides ip_range_blacklist and defaults to an empty +# list. +# +#ip_range_whitelist: +# - '192.168.1.1' + # List of ports that Synapse should listen on, their purpose and their # configuration. # @@ -642,27 +683,6 @@ acme: # - nyc.example.com # - syd.example.com -# Prevent federation requests from being sent to the following -# blacklist IP address CIDR ranges. If this option is not specified, or -# specified with an empty list, no ip range blacklist will be enforced. -# -# As of Synapse v1.4.0 this option also affects any outbound requests to identity -# servers provided by user input. -# -# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly -# listed here, since they correspond to unroutable addresses.) -# -federation_ip_range_blacklist: - - '127.0.0.0/8' - - '10.0.0.0/8' - - '172.16.0.0/12' - - '192.168.0.0/16' - - '100.64.0.0/10' - - '169.254.0.0/16' - - '::1/128' - - 'fe80::/64' - - 'fc00::/7' - # Report prometheus metrics on the age of PDUs being sent to and received from # the following domains. This can be used to give an idea of "delay" on inbound # and outbound federation, though be aware that any delay can be due to problems @@ -893,7 +913,7 @@ media_store_path: "DATADIR/media_store" # The largest allowed upload size in bytes # -#max_upload_size: 10M +#max_upload_size: 50M # Maximum number of pixels that will be thumbnailed # @@ -953,9 +973,15 @@ media_store_path: "DATADIR/media_store" # - '172.16.0.0/12' # - '192.168.0.0/16' # - '100.64.0.0/10' +# - '192.0.0.0/24' # - '169.254.0.0/16' +# - '198.18.0.0/15' +# - '192.0.2.0/24' +# - '198.51.100.0/24' +# - '203.0.113.0/24' +# - '224.0.0.0/4' # - '::1/128' -# - 'fe80::/64' +# - 'fe80::/10' # - 'fc00::/7' # List of IP address CIDR ranges that the URL preview spider is allowed @@ -1230,8 +1256,9 @@ account_validity: # email will be globally disabled. 
# # Additionally, if `msisdn` is not set, registration and password resets via msisdn -# will be disabled regardless. This is due to Synapse currently not supporting any -# method of sending SMS messages on its own. +# will be disabled regardless, and users will not be able to associate an msisdn +# identifier to their account. This is due to Synapse currently not supporting +# any method of sending SMS messages on its own. # # To enable using an identity server for operations regarding a particular third-party # identifier type, set the value to the URL of that identity server as shown in the @@ -1505,10 +1532,8 @@ trusted_key_servers: ## Single sign-on integration ## -# Enable SAML2 for registration and login. Uses pysaml2. -# -# At least one of `sp_config` or `config_path` must be set in this section to -# enable SAML login. +# The following settings can be used to make Synapse use a single sign-on +# provider for authentication, instead of its internal password database. # # You will probably also want to set the following options to `false` to # disable the regular login/registration flows: @@ -1517,6 +1542,11 @@ trusted_key_servers: # # You will also want to investigate the settings under the "sso" configuration # section below. + +# Enable SAML2 for registration and login. Uses pysaml2. +# +# At least one of `sp_config` or `config_path` must be set in this section to +# enable SAML login. # # Once SAML support is enabled, a metadata file will be exposed at # https://<server>:<port>/_matrix/saml2/metadata.xml, which you may be able to @@ -1532,40 +1562,70 @@ saml2_config: # so it is not normally necessary to specify them unless you need to # override them. # - #sp_config: - # # point this to the IdP's metadata. You can use either a local file or - # # (preferably) a URL. - # metadata: - # #local: ["saml2/idp.xml"] - # remote: - # - url: https://our_idp/metadata.xml - # - # # By default, the user has to go to our login page first. If you'd like - # # to allow IdP-initiated login, set 'allow_unsolicited: true' in a - # # 'service.sp' section: - # # - # #service: - # # sp: - # # allow_unsolicited: true - # - # # The examples below are just used to generate our metadata xml, and you - # # may well not need them, depending on your setup. Alternatively you - # # may need a whole lot more detail - see the pysaml2 docs! - # - # description: ["My awesome SP", "en"] - # name: ["Test SP", "en"] - # - # organization: - # name: Example com - # display_name: - # - ["Example co", "en"] - # url: "http://example.com" - # - # contact_person: - # - given_name: Bob - # sur_name: "the Sysadmin" - # email_address": ["admin@example.com"] - # contact_type": technical + sp_config: + # Point this to the IdP's metadata. You must provide either a local + # file via the `local` attribute or (preferably) a URL via the + # `remote` attribute. + # + #metadata: + # local: ["saml2/idp.xml"] + # remote: + # - url: https://our_idp/metadata.xml + + # Allowed clock difference in seconds between the homeserver and IdP. + # + # Uncomment the below to increase the accepted time difference from 0 to 3 seconds. + # + #accepted_time_diff: 3 + + # By default, the user has to go to our login page first. If you'd like + # to allow IdP-initiated login, set 'allow_unsolicited: true' in a + # 'service.sp' section: + # + #service: + # sp: + # allow_unsolicited: true + + # The examples below are just used to generate our metadata xml, and you + # may well not need them, depending on your setup. 
Alternatively you + # may need a whole lot more detail - see the pysaml2 docs! + + #description: ["My awesome SP", "en"] + #name: ["Test SP", "en"] + + #ui_info: + # display_name: + # - lang: en + # text: "Display Name is the descriptive name of your service." + # description: + # - lang: en + # text: "Description should be a short paragraph explaining the purpose of the service." + # information_url: + # - lang: en + # text: "https://example.com/terms-of-service" + # privacy_statement_url: + # - lang: en + # text: "https://example.com/privacy-policy" + # keywords: + # - lang: en + # text: ["Matrix", "Element"] + # logo: + # - lang: en + # text: "https://example.com/logo.svg" + # width: "200" + # height: "80" + + #organization: + # name: Example com + # display_name: + # - ["Example co", "en"] + # url: "http://example.com" + + #contact_person: + # - given_name: Bob + # sur_name: "the Sysadmin" + # email_address": ["admin@example.com"] + # contact_type": technical # Instead of putting the config inline as above, you can specify a # separate pysaml2 configuration file: @@ -1640,12 +1700,19 @@ saml2_config: # - attribute: department # value: "sales" + # If the metadata XML contains multiple IdP entities then the `idp_entityid` + # option must be set to the entity to redirect users to. + # + # Most deployments only have a single IdP entity and so should omit this + # option. + # + #idp_entityid: 'https://our_idp/entityid' -# OpenID Connect integration. The following settings can be used to make Synapse -# use an OpenID Connect Provider for authentication, instead of its internal -# password database. + +# Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login. # -# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md. +# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md +# for some example configurations. # oidc_config: # Uncomment the following to enable authorization against an OpenID Connect @@ -1714,6 +1781,14 @@ oidc_config: # #skip_verification: true + # Whether to fetch the user profile from the userinfo endpoint. Valid + # values are: "auto" or "userinfo_endpoint". + # + # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included + # in `scopes`. Uncomment the following to always fetch the userinfo endpoint. + # + #user_profile_method: "userinfo_endpoint" + # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead # of failing. This could be used if switching from password logins to OIDC. Defaults to false. # @@ -1750,9 +1825,10 @@ oidc_config: # * user: The claims returned by the UserInfo Endpoint and/or in the ID # Token # - # This must be configured if using the default mapping provider. + # If this is not set, the user will be prompted to choose their + # own username. # - localpart_template: "{{ user.preferred_username }}" + #localpart_template: "{{ user.preferred_username }}" # Jinja2 template for the display name to set on first login. # @@ -1770,15 +1846,37 @@ oidc_config: -# Enable CAS for registration and login. +# Enable Central Authentication Service (CAS) for registration and login. # -#cas_config: -# enabled: true -# server_url: "https://cas-server.com" -# service_url: "https://homeserver.domain.com:8448" -# #displayname_attribute: name -# #required_attributes: -# # name: value +cas_config: + # Uncomment the following to enable authorization against a CAS server. + # Defaults to false. + # + #enabled: true + + # The URL of the CAS authorization endpoint. 
+ # + #server_url: "https://cas-server.com" + + # The public URL of the homeserver. + # + #service_url: "https://homeserver.domain.com:8448" + + # The attribute of the CAS response to use as the display name. + # + # If unset, no displayname will be set. + # + #displayname_attribute: name + + # It is possible to configure Synapse to only allow logins if CAS attributes + # match particular values. All of the keys in the mapping below must exist + # and the values must match the given value. Alternately if the given value + # is None then any value is allowed (the attribute just must exist). + # All of the listed attributes must match for the login to be permitted. + # + #required_attributes: + # userGroup: "staff" + # department: None # Additional settings to use with single-sign on systems such as OpenID Connect, @@ -1806,11 +1904,8 @@ sso: # - https://my.custom.client/ # Directory in which Synapse will try to find the template files below. - # If not set, default templates from within the Synapse package will be used. - # - # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates. - # If you *do* uncomment it, you will need to make sure that all the templates - # below are in the directory. + # If not set, or the files named below are not found within the template + # directory, default templates from within the Synapse package will be used. # # Synapse will look for the following templates in this directory: # @@ -1878,7 +1973,7 @@ sso: # and issued at ("iat") claims are validated if present. # # Note that this is a non-standard login type and client support is -# expected to be non-existant. +# expected to be non-existent. # # See https://github.com/matrix-org/synapse/blob/master/docs/jwt.md. # @@ -1974,6 +2069,21 @@ password_config: # #require_uppercase: true +ui_auth: + # The number of milliseconds to allow a user-interactive authentication + # session to be active. + # + # This defaults to 0, meaning the user is queried for their credentials + # before every action, but this can be overridden to alow a single + # validation to be re-used. This weakens the protections afforded by + # the user-interactive authentication process, by allowing for multiple + # (and potentially different) operations to use the same validation session. + # + # Uncomment below to allow for credential validation to last for 15 + # seconds. + # + #session_timeout: 15000 + # Configuration for sending emails from Synapse. # @@ -2039,10 +2149,15 @@ email: # #validation_token_lifetime: 15m - # Directory in which Synapse will try to find the template files below. - # If not set, default templates from within the Synapse package will be used. + # The web client location to direct users to during an invite. This is passed + # to the identity server as the org.matrix.web_client_location key. Defaults + # to unset, giving no guidance to the identity server. # - # Do not uncomment this setting unless you want to customise the templates. + #invite_client_location: https://app.element.io + + # Directory in which Synapse will try to find the template files below. + # If not set, or the files named below are not found within the template + # directory, default templates from within the Synapse package will be used. 
# # Synapse will look for the following templates in this directory: # @@ -2180,20 +2295,35 @@ password_providers: -# Clients requesting push notifications can either have the body of -# the message sent in the notification poke along with other details -# like the sender, or just the event ID and room ID (`event_id_only`). -# If clients choose the former, this option controls whether the -# notification request includes the content of the event (other details -# like the sender are still included). For `event_id_only` push, it -# has no effect. -# -# For modern android devices the notification content will still appear -# because it is loaded by the app. iPhone, however will send a -# notification saying only that a message arrived and who it came from. -# -#push: -# include_content: true +## Push ## + +push: + # Clients requesting push notifications can either have the body of + # the message sent in the notification poke along with other details + # like the sender, or just the event ID and room ID (`event_id_only`). + # If clients choose the former, this option controls whether the + # notification request includes the content of the event (other details + # like the sender are still included). For `event_id_only` push, it + # has no effect. + # + # For modern android devices the notification content will still appear + # because it is loaded by the app. iPhone, however will send a + # notification saying only that a message arrived and who it came from. + # + # The default value is "true" to include message details. Uncomment to only + # include the event ID and room ID in push notification payloads. + # + #include_content: false + + # When a push notification is received, an unread count is also sent. + # This number can either be calculated as the number of unread messages + # for the user, or the number of *rooms* the user has unread messages in. + # + # The default value is "true", meaning push clients will see the number of + # rooms with unread messages in them. Uncomment to instead send the number + # of unread messages. + # + #group_unread_count_by_room: false # Spam checkers are third-party modules that can block specific actions @@ -2236,7 +2366,7 @@ spam_checker: # If enabled, non server admins can only create groups with local parts # starting with this prefix # -#group_creation_prefix: "unofficial/" +#group_creation_prefix: "unofficial_" @@ -2394,7 +2524,7 @@ spam_checker: # # Options for the rules include: # -# user_id: Matches agaisnt the creator of the alias +# user_id: Matches against the creator of the alias # room_id: Matches against the room ID being published # alias: Matches against any current local or canonical aliases # associated with the room @@ -2440,7 +2570,7 @@ opentracing: # This is a list of regexes which are matched against the server_name of the # homeserver. # - # By defult, it is empty, so no servers are matched. + # By default, it is empty, so no servers are matched. # #homeserver_whitelist: # - ".*" @@ -2496,6 +2626,18 @@ opentracing: # events: worker1 # typing: worker1 +# The worker that is used to run background tasks (e.g. cleaning up expired +# data). If not provided this defaults to the main process. +# +#run_background_tasks_on: worker1 + +# A shared secret used by the replication APIs to authenticate HTTP requests +# from workers. +# +# By default this is unused and traffic is not authenticated. +# +#worker_replication_secret: "" + # Configuration for Redis when using workers. 
This *must* be enabled when # using workers (unless using old style direct TCP configuration). diff --git a/docs/sample_log_config.yaml b/docs/sample_log_config.yaml
index 55a48a9ed6..ff3c747180 100644 --- a/docs/sample_log_config.yaml +++ b/docs/sample_log_config.yaml
@@ -3,7 +3,11 @@ # This is a YAML file containing a standard Python logging configuration # dictionary. See [1] for details on the valid settings. # +# Synapse also supports structured logging for machine readable logs which can +# be ingested by ELK stacks. See [2] for details. +# # [1]: https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema +# [2]: https://github.com/matrix-org/synapse/blob/master/docs/structured_logging.md version: 1 @@ -59,7 +63,7 @@ root: # then write them to a file. # # Replace "buffer" with "console" to log to stderr instead. (Note that you'll - # also need to update the configuation for the `twisted` logger above, in + # also need to update the configuration for the `twisted` logger above, in # this case.) # handlers: [buffer] diff --git a/docs/spam_checker.md b/docs/spam_checker.md
index eb10e115f9..5b4f6428e6 100644 --- a/docs/spam_checker.md +++ b/docs/spam_checker.md
@@ -11,7 +11,7 @@ able to be imported by the running Synapse. The Python class is instantiated with two objects: * Any configuration (see below). -* An instance of `synapse.spam_checker_api.SpamCheckerApi`. +* An instance of `synapse.module_api.ModuleApi`. It then implements methods which return a boolean to alter behavior in Synapse. @@ -22,41 +22,45 @@ well as some specific methods: * `user_may_create_room` * `user_may_create_room_alias` * `user_may_publish_room` +* `check_username_for_spam` +* `check_registration_for_spam` The details of the each of these methods (as well as their inputs and outputs) are documented in the `synapse.events.spamcheck.SpamChecker` class. -The `SpamCheckerApi` class provides a way for the custom spam checker class to -call back into the homeserver internals. It currently implements the following -methods: - -* `get_state_events_in_room` +The `ModuleApi` class provides a way for the custom spam checker class to +call back into the homeserver internals. ### Example ```python +from synapse.spam_checker_api import RegistrationBehaviour + class ExampleSpamChecker: def __init__(self, config, api): self.config = config self.api = api - def check_event_for_spam(self, foo): + async def check_event_for_spam(self, foo): return False # allow all events - def user_may_invite(self, inviter_userid, invitee_userid, room_id): + async def user_may_invite(self, inviter_userid, invitee_userid, room_id): return True # allow all invites - def user_may_create_room(self, userid): + async def user_may_create_room(self, userid): return True # allow all room creations - def user_may_create_room_alias(self, userid, room_alias): + async def user_may_create_room_alias(self, userid, room_alias): return True # allow all room aliases - def user_may_publish_room(self, userid, room_id): + async def user_may_publish_room(self, userid, room_id): return True # allow publishing of all rooms - def check_username_for_spam(self, user_profile): + async def check_username_for_spam(self, user_profile): return False # allow all usernames + + async def check_registration_for_spam(self, email_threepid, username, request_info): + return RegistrationBehaviour.ALLOW # allow all registrations ``` ## Configuration diff --git a/docs/sphinx/README.rst b/docs/sphinx/README.rst deleted file mode 100644
index a7ab7c5500..0000000000 --- a/docs/sphinx/README.rst +++ /dev/null
@@ -1 +0,0 @@ -TODO: how (if at all) is this actually maintained? diff --git a/docs/sphinx/conf.py b/docs/sphinx/conf.py deleted file mode 100644
index ca4b879526..0000000000 --- a/docs/sphinx/conf.py +++ /dev/null
@@ -1,271 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Synapse documentation build configuration file, created by -# sphinx-quickstart on Tue Jun 10 17:31:02 2014. -# -# This file is execfile()d with the current directory set to its -# containing dir. -# -# Note that not all possible configuration values are present in this -# autogenerated file. -# -# All configuration values have a default; values that are commented out -# serve to show the default. - -import sys -import os - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -sys.path.insert(0, os.path.abspath("..")) - -# -- General configuration ------------------------------------------------ - -# If your documentation needs a minimal Sphinx version, state it here. -# needs_sphinx = '1.0' - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. -extensions = [ - "sphinx.ext.autodoc", - "sphinx.ext.intersphinx", - "sphinx.ext.coverage", - "sphinx.ext.ifconfig", - "sphinxcontrib.napoleon", -] - -# Add any paths that contain templates here, relative to this directory. -templates_path = ["_templates"] - -# The suffix of source filenames. -source_suffix = ".rst" - -# The encoding of source files. -# source_encoding = 'utf-8-sig' - -# The master toctree document. -master_doc = "index" - -# General information about the project. -project = "Synapse" -copyright = ( - "Copyright 2014-2017 OpenMarket Ltd, 2017 Vector Creations Ltd, 2017 New Vector Ltd" -) - -# The version info for the project you're documenting, acts as replacement for -# |version| and |release|, also used in various other places throughout the -# built documents. -# -# The short X.Y version. -version = "1.0" -# The full version, including alpha/beta/rc tags. -release = "1.0" - -# The language for content autogenerated by Sphinx. Refer to documentation -# for a list of supported languages. -# language = None - -# There are two options for replacing |today|: either, you set today to some -# non-false value, then it is used: -# today = '' -# Else, today_fmt is used as the format for a strftime call. -# today_fmt = '%B %d, %Y' - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -exclude_patterns = ["_build"] - -# The reST default role (used for this markup: `text`) to use for all -# documents. -# default_role = None - -# If true, '()' will be appended to :func: etc. cross-reference text. -# add_function_parentheses = True - -# If true, the current module name will be prepended to all description -# unit titles (such as .. function::). -# add_module_names = True - -# If true, sectionauthor and moduleauthor directives will be shown in the -# output. They are ignored by default. -# show_authors = False - -# The name of the Pygments (syntax highlighting) style to use. -pygments_style = "sphinx" - -# A list of ignored prefixes for module index sorting. -# modindex_common_prefix = [] - -# If true, keep warnings as "system message" paragraphs in the built documents. -# keep_warnings = False - - -# -- Options for HTML output ---------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. 
-html_theme = "default" - -# Theme options are theme-specific and customize the look and feel of a theme -# further. For a list of options available for each theme, see the -# documentation. -# html_theme_options = {} - -# Add any paths that contain custom themes here, relative to this directory. -# html_theme_path = [] - -# The name for this set of Sphinx documents. If None, it defaults to -# "<project> v<release> documentation". -# html_title = None - -# A shorter title for the navigation bar. Default is the same as html_title. -# html_short_title = None - -# The name of an image file (relative to this directory) to place at the top -# of the sidebar. -# html_logo = None - -# The name of an image file (within the static path) to use as favicon of the -# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 -# pixels large. -# html_favicon = None - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ["_static"] - -# Add any extra paths that contain custom files (such as robots.txt or -# .htaccess) here, relative to this directory. These files are copied -# directly to the root of the documentation. -# html_extra_path = [] - -# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, -# using the given strftime format. -# html_last_updated_fmt = '%b %d, %Y' - -# If true, SmartyPants will be used to convert quotes and dashes to -# typographically correct entities. -# html_use_smartypants = True - -# Custom sidebar templates, maps document names to template names. -# html_sidebars = {} - -# Additional templates that should be rendered to pages, maps page names to -# template names. -# html_additional_pages = {} - -# If false, no module index is generated. -# html_domain_indices = True - -# If false, no index is generated. -# html_use_index = True - -# If true, the index is split into individual pages for each letter. -# html_split_index = False - -# If true, links to the reST sources are added to the pages. -# html_show_sourcelink = True - -# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -# html_show_sphinx = True - -# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -# html_show_copyright = True - -# If true, an OpenSearch description file will be output, and all pages will -# contain a <link> tag referring to it. The value of this option must be the -# base URL from which the finished HTML is served. -# html_use_opensearch = '' - -# This is the file name suffix for HTML files (e.g. ".xhtml"). -# html_file_suffix = None - -# Output file base name for HTML help builder. -htmlhelp_basename = "Synapsedoc" - - -# -- Options for LaTeX output --------------------------------------------- - -latex_elements = { - # The paper size ('letterpaper' or 'a4paper'). - #'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). - #'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. - #'preamble': '', -} - -# Grouping the document tree into LaTeX files. List of tuples -# (source start file, target name, title, -# author, documentclass [howto, manual, or own class]). -latex_documents = [("index", "Synapse.tex", "Synapse Documentation", "TNG", "manual")] - -# The name of an image file (relative to this directory) to place at the top of -# the title page. 
-# latex_logo = None - -# For "manual" documents, if this is true, then toplevel headings are parts, -# not chapters. -# latex_use_parts = False - -# If true, show page references after internal links. -# latex_show_pagerefs = False - -# If true, show URL addresses after external links. -# latex_show_urls = False - -# Documents to append as an appendix to all manuals. -# latex_appendices = [] - -# If false, no module index is generated. -# latex_domain_indices = True - - -# -- Options for manual page output --------------------------------------- - -# One entry per manual page. List of tuples -# (source start file, name, description, authors, manual section). -man_pages = [("index", "synapse", "Synapse Documentation", ["TNG"], 1)] - -# If true, show URL addresses after external links. -# man_show_urls = False - - -# -- Options for Texinfo output ------------------------------------------- - -# Grouping the document tree into Texinfo files. List of tuples -# (source start file, target name, title, author, -# dir menu entry, description, category) -texinfo_documents = [ - ( - "index", - "Synapse", - "Synapse Documentation", - "TNG", - "Synapse", - "One line description of project.", - "Miscellaneous", - ) -] - -# Documents to append as an appendix to all manuals. -# texinfo_appendices = [] - -# If false, no module index is generated. -# texinfo_domain_indices = True - -# How to display URL addresses: 'footnote', 'no', or 'inline'. -# texinfo_show_urls = 'footnote' - -# If true, do not generate a @detailmenu in the "Top" node's menu. -# texinfo_no_detailmenu = False - - -# Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = {"http://docs.python.org/": None} - -napoleon_include_special_with_doc = True -napoleon_use_ivar = True diff --git a/docs/sphinx/index.rst b/docs/sphinx/index.rst deleted file mode 100644
index 76a4c0c7bf..0000000000 --- a/docs/sphinx/index.rst +++ /dev/null
@@ -1,20 +0,0 @@ -.. Synapse documentation master file, created by - sphinx-quickstart on Tue Jun 10 17:31:02 2014. - You can adapt this file completely to your liking, but it should at least - contain the root `toctree` directive. - -Welcome to Synapse's documentation! -=================================== - -Contents: - -.. toctree:: - synapse - -Indices and tables -================== - -* :ref:`genindex` -* :ref:`modindex` -* :ref:`search` - diff --git a/docs/sphinx/modules.rst b/docs/sphinx/modules.rst deleted file mode 100644
index 1c7f70bd13..0000000000 --- a/docs/sphinx/modules.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse -======= - -.. toctree:: - :maxdepth: 4 - - synapse diff --git a/docs/sphinx/synapse.api.auth.rst b/docs/sphinx/synapse.api.auth.rst deleted file mode 100644
index 931eb59836..0000000000 --- a/docs/sphinx/synapse.api.auth.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.api.auth module -======================= - -.. automodule:: synapse.api.auth - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.api.constants.rst b/docs/sphinx/synapse.api.constants.rst deleted file mode 100644
index a1e3c47f68..0000000000 --- a/docs/sphinx/synapse.api.constants.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.api.constants module -============================ - -.. automodule:: synapse.api.constants - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.api.dbobjects.rst b/docs/sphinx/synapse.api.dbobjects.rst deleted file mode 100644
index e9d31167e0..0000000000 --- a/docs/sphinx/synapse.api.dbobjects.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.api.dbobjects module -============================ - -.. automodule:: synapse.api.dbobjects - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.api.errors.rst b/docs/sphinx/synapse.api.errors.rst deleted file mode 100644
index f1c6881478..0000000000 --- a/docs/sphinx/synapse.api.errors.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.api.errors module -========================= - -.. automodule:: synapse.api.errors - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.api.event_stream.rst b/docs/sphinx/synapse.api.event_stream.rst deleted file mode 100644
index 9291cb2dbc..0000000000 --- a/docs/sphinx/synapse.api.event_stream.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.api.event_stream module -=============================== - -.. automodule:: synapse.api.event_stream - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.api.events.factory.rst b/docs/sphinx/synapse.api.events.factory.rst deleted file mode 100644
index 2e71ff6070..0000000000 --- a/docs/sphinx/synapse.api.events.factory.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.api.events.factory module -================================= - -.. automodule:: synapse.api.events.factory - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.api.events.room.rst b/docs/sphinx/synapse.api.events.room.rst deleted file mode 100644
index 6cd5998599..0000000000 --- a/docs/sphinx/synapse.api.events.room.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.api.events.room module -============================== - -.. automodule:: synapse.api.events.room - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.api.events.rst b/docs/sphinx/synapse.api.events.rst deleted file mode 100644
index b762da55ee..0000000000 --- a/docs/sphinx/synapse.api.events.rst +++ /dev/null
@@ -1,18 +0,0 @@ -synapse.api.events package -========================== - -Submodules ----------- - -.. toctree:: - - synapse.api.events.factory - synapse.api.events.room - -Module contents ---------------- - -.. automodule:: synapse.api.events - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.api.handlers.events.rst b/docs/sphinx/synapse.api.handlers.events.rst deleted file mode 100644
index d2e1b54ac0..0000000000 --- a/docs/sphinx/synapse.api.handlers.events.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.api.handlers.events module -================================== - -.. automodule:: synapse.api.handlers.events - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.api.handlers.factory.rst b/docs/sphinx/synapse.api.handlers.factory.rst deleted file mode 100644
index b04a93f740..0000000000 --- a/docs/sphinx/synapse.api.handlers.factory.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.api.handlers.factory module -=================================== - -.. automodule:: synapse.api.handlers.factory - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.api.handlers.federation.rst b/docs/sphinx/synapse.api.handlers.federation.rst deleted file mode 100644
index 61a6542210..0000000000 --- a/docs/sphinx/synapse.api.handlers.federation.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.api.handlers.federation module -====================================== - -.. automodule:: synapse.api.handlers.federation - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.api.handlers.register.rst b/docs/sphinx/synapse.api.handlers.register.rst deleted file mode 100644
index 388f144eca..0000000000 --- a/docs/sphinx/synapse.api.handlers.register.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.api.handlers.register module -==================================== - -.. automodule:: synapse.api.handlers.register - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.api.handlers.room.rst b/docs/sphinx/synapse.api.handlers.room.rst deleted file mode 100644
index 8ca156c7ff..0000000000 --- a/docs/sphinx/synapse.api.handlers.room.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.api.handlers.room module -================================ - -.. automodule:: synapse.api.handlers.room - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.api.handlers.rst b/docs/sphinx/synapse.api.handlers.rst deleted file mode 100644
index e84f563fcb..0000000000 --- a/docs/sphinx/synapse.api.handlers.rst +++ /dev/null
@@ -1,21 +0,0 @@ -synapse.api.handlers package -============================ - -Submodules ----------- - -.. toctree:: - - synapse.api.handlers.events - synapse.api.handlers.factory - synapse.api.handlers.federation - synapse.api.handlers.register - synapse.api.handlers.room - -Module contents ---------------- - -.. automodule:: synapse.api.handlers - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.api.notifier.rst b/docs/sphinx/synapse.api.notifier.rst deleted file mode 100644
index 631b42a497..0000000000 --- a/docs/sphinx/synapse.api.notifier.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.api.notifier module -=========================== - -.. automodule:: synapse.api.notifier - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.api.register_events.rst b/docs/sphinx/synapse.api.register_events.rst deleted file mode 100644
index 79ad4ce211..0000000000 --- a/docs/sphinx/synapse.api.register_events.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.api.register_events module -================================== - -.. automodule:: synapse.api.register_events - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.api.room_events.rst b/docs/sphinx/synapse.api.room_events.rst deleted file mode 100644
index bead1711f5..0000000000 --- a/docs/sphinx/synapse.api.room_events.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.api.room_events module -============================== - -.. automodule:: synapse.api.room_events - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.api.rst b/docs/sphinx/synapse.api.rst deleted file mode 100644
index f4d39ff331..0000000000 --- a/docs/sphinx/synapse.api.rst +++ /dev/null
@@ -1,30 +0,0 @@ -synapse.api package -=================== - -Subpackages ------------ - -.. toctree:: - - synapse.api.events - synapse.api.handlers - synapse.api.streams - -Submodules ----------- - -.. toctree:: - - synapse.api.auth - synapse.api.constants - synapse.api.errors - synapse.api.notifier - synapse.api.storage - -Module contents ---------------- - -.. automodule:: synapse.api - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.api.server.rst b/docs/sphinx/synapse.api.server.rst deleted file mode 100644
index b01600235e..0000000000 --- a/docs/sphinx/synapse.api.server.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.api.server module -========================= - -.. automodule:: synapse.api.server - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.api.storage.rst b/docs/sphinx/synapse.api.storage.rst deleted file mode 100644
index afa40685c4..0000000000 --- a/docs/sphinx/synapse.api.storage.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.api.storage module -========================== - -.. automodule:: synapse.api.storage - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.api.stream.rst b/docs/sphinx/synapse.api.stream.rst deleted file mode 100644
index 0d5e3f01bf..0000000000 --- a/docs/sphinx/synapse.api.stream.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.api.stream module -========================= - -.. automodule:: synapse.api.stream - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.api.streams.event.rst b/docs/sphinx/synapse.api.streams.event.rst deleted file mode 100644
index 2ac45a35c8..0000000000 --- a/docs/sphinx/synapse.api.streams.event.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.api.streams.event module -================================ - -.. automodule:: synapse.api.streams.event - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.api.streams.rst b/docs/sphinx/synapse.api.streams.rst deleted file mode 100644
index 72eb205caf..0000000000 --- a/docs/sphinx/synapse.api.streams.rst +++ /dev/null
@@ -1,17 +0,0 @@ -synapse.api.streams package -=========================== - -Submodules ----------- - -.. toctree:: - - synapse.api.streams.event - -Module contents ---------------- - -.. automodule:: synapse.api.streams - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.app.homeserver.rst b/docs/sphinx/synapse.app.homeserver.rst deleted file mode 100644
index 54b93da8fe..0000000000 --- a/docs/sphinx/synapse.app.homeserver.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.app.homeserver module -============================= - -.. automodule:: synapse.app.homeserver - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.app.rst b/docs/sphinx/synapse.app.rst deleted file mode 100644
index 4535b79827..0000000000 --- a/docs/sphinx/synapse.app.rst +++ /dev/null
@@ -1,17 +0,0 @@ -synapse.app package -=================== - -Submodules ----------- - -.. toctree:: - - synapse.app.homeserver - -Module contents ---------------- - -.. automodule:: synapse.app - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.db.rst b/docs/sphinx/synapse.db.rst deleted file mode 100644
index 83df6c03db..0000000000 --- a/docs/sphinx/synapse.db.rst +++ /dev/null
@@ -1,10 +0,0 @@ -synapse.db package -================== - -Module contents ---------------- - -.. automodule:: synapse.db - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.federation.handler.rst b/docs/sphinx/synapse.federation.handler.rst deleted file mode 100644
index 5597f5c46d..0000000000 --- a/docs/sphinx/synapse.federation.handler.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.federation.handler module -================================= - -.. automodule:: synapse.federation.handler - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.federation.messaging.rst b/docs/sphinx/synapse.federation.messaging.rst deleted file mode 100644
index 4bbaabf3ef..0000000000 --- a/docs/sphinx/synapse.federation.messaging.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.federation.messaging module -=================================== - -.. automodule:: synapse.federation.messaging - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.federation.pdu_codec.rst b/docs/sphinx/synapse.federation.pdu_codec.rst deleted file mode 100644
index 8f0b15a63c..0000000000 --- a/docs/sphinx/synapse.federation.pdu_codec.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.federation.pdu_codec module -=================================== - -.. automodule:: synapse.federation.pdu_codec - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.federation.persistence.rst b/docs/sphinx/synapse.federation.persistence.rst deleted file mode 100644
index db7ab8ade1..0000000000 --- a/docs/sphinx/synapse.federation.persistence.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.federation.persistence module -===================================== - -.. automodule:: synapse.federation.persistence - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.federation.replication.rst b/docs/sphinx/synapse.federation.replication.rst deleted file mode 100644
index 49e26e0928..0000000000 --- a/docs/sphinx/synapse.federation.replication.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.federation.replication module -===================================== - -.. automodule:: synapse.federation.replication - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.federation.rst b/docs/sphinx/synapse.federation.rst deleted file mode 100644
index 7240c7901b..0000000000 --- a/docs/sphinx/synapse.federation.rst +++ /dev/null
@@ -1,22 +0,0 @@ -synapse.federation package -========================== - -Submodules ----------- - -.. toctree:: - - synapse.federation.handler - synapse.federation.pdu_codec - synapse.federation.persistence - synapse.federation.replication - synapse.federation.transport - synapse.federation.units - -Module contents ---------------- - -.. automodule:: synapse.federation - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.federation.transport.rst b/docs/sphinx/synapse.federation.transport.rst deleted file mode 100644
index 877956b3c9..0000000000 --- a/docs/sphinx/synapse.federation.transport.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.federation.transport module -=================================== - -.. automodule:: synapse.federation.transport - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.federation.units.rst b/docs/sphinx/synapse.federation.units.rst deleted file mode 100644
index 8f9212b07d..0000000000 --- a/docs/sphinx/synapse.federation.units.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.federation.units module -=============================== - -.. automodule:: synapse.federation.units - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.persistence.rst b/docs/sphinx/synapse.persistence.rst deleted file mode 100644
index 37c0c23720..0000000000 --- a/docs/sphinx/synapse.persistence.rst +++ /dev/null
@@ -1,19 +0,0 @@ -synapse.persistence package -=========================== - -Submodules ----------- - -.. toctree:: - - synapse.persistence.service - synapse.persistence.tables - synapse.persistence.transactions - -Module contents ---------------- - -.. automodule:: synapse.persistence - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.persistence.service.rst b/docs/sphinx/synapse.persistence.service.rst deleted file mode 100644
index 3514d3c76f..0000000000 --- a/docs/sphinx/synapse.persistence.service.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.persistence.service module -================================== - -.. automodule:: synapse.persistence.service - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.persistence.tables.rst b/docs/sphinx/synapse.persistence.tables.rst deleted file mode 100644
index 907b02769d..0000000000 --- a/docs/sphinx/synapse.persistence.tables.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.persistence.tables module -================================= - -.. automodule:: synapse.persistence.tables - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.persistence.transactions.rst b/docs/sphinx/synapse.persistence.transactions.rst deleted file mode 100644
index 475c02a8c5..0000000000 --- a/docs/sphinx/synapse.persistence.transactions.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.persistence.transactions module -======================================= - -.. automodule:: synapse.persistence.transactions - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.rest.base.rst b/docs/sphinx/synapse.rest.base.rst deleted file mode 100644
index 84d2d9b31d..0000000000 --- a/docs/sphinx/synapse.rest.base.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.rest.base module -======================== - -.. automodule:: synapse.rest.base - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.rest.events.rst b/docs/sphinx/synapse.rest.events.rst deleted file mode 100644
index ebbe26c746..0000000000 --- a/docs/sphinx/synapse.rest.events.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.rest.events module -========================== - -.. automodule:: synapse.rest.events - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.rest.register.rst b/docs/sphinx/synapse.rest.register.rst deleted file mode 100644
index a4a48a8a8f..0000000000 --- a/docs/sphinx/synapse.rest.register.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.rest.register module -============================ - -.. automodule:: synapse.rest.register - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.rest.room.rst b/docs/sphinx/synapse.rest.room.rst deleted file mode 100644
index 63fc5c2840..0000000000 --- a/docs/sphinx/synapse.rest.room.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.rest.room module -======================== - -.. automodule:: synapse.rest.room - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.rest.rst b/docs/sphinx/synapse.rest.rst deleted file mode 100644
index 016af926b2..0000000000 --- a/docs/sphinx/synapse.rest.rst +++ /dev/null
@@ -1,20 +0,0 @@ -synapse.rest package -==================== - -Submodules ----------- - -.. toctree:: - - synapse.rest.base - synapse.rest.events - synapse.rest.register - synapse.rest.room - -Module contents ---------------- - -.. automodule:: synapse.rest - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.rst b/docs/sphinx/synapse.rst deleted file mode 100644
index e7869e0e5d..0000000000 --- a/docs/sphinx/synapse.rst +++ /dev/null
@@ -1,30 +0,0 @@ -synapse package -=============== - -Subpackages ------------ - -.. toctree:: - - synapse.api - synapse.app - synapse.federation - synapse.persistence - synapse.rest - synapse.util - -Submodules ----------- - -.. toctree:: - - synapse.server - synapse.state - -Module contents ---------------- - -.. automodule:: synapse - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.server.rst b/docs/sphinx/synapse.server.rst deleted file mode 100644
index 7f33f084d7..0000000000 --- a/docs/sphinx/synapse.server.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.server module -===================== - -.. automodule:: synapse.server - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.state.rst b/docs/sphinx/synapse.state.rst deleted file mode 100644
index 744be2a8be..0000000000 --- a/docs/sphinx/synapse.state.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.state module -==================== - -.. automodule:: synapse.state - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.util.async.rst b/docs/sphinx/synapse.util.async.rst deleted file mode 100644
index 542bb54444..0000000000 --- a/docs/sphinx/synapse.util.async.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.util.async module -========================= - -.. automodule:: synapse.util.async - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.util.dbutils.rst b/docs/sphinx/synapse.util.dbutils.rst deleted file mode 100644
index afaa9eb749..0000000000 --- a/docs/sphinx/synapse.util.dbutils.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.util.dbutils module -=========================== - -.. automodule:: synapse.util.dbutils - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.util.http.rst b/docs/sphinx/synapse.util.http.rst deleted file mode 100644
index 344af5a490..0000000000 --- a/docs/sphinx/synapse.util.http.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.util.http module -======================== - -.. automodule:: synapse.util.http - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.util.lockutils.rst b/docs/sphinx/synapse.util.lockutils.rst deleted file mode 100644
index 16ee26cabd..0000000000 --- a/docs/sphinx/synapse.util.lockutils.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.util.lockutils module -============================= - -.. automodule:: synapse.util.lockutils - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.util.logutils.rst b/docs/sphinx/synapse.util.logutils.rst deleted file mode 100644
index 2b79fa7a4b..0000000000 --- a/docs/sphinx/synapse.util.logutils.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.util.logutils module -============================ - -.. automodule:: synapse.util.logutils - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.util.rst b/docs/sphinx/synapse.util.rst deleted file mode 100644
index 01a0c3a591..0000000000 --- a/docs/sphinx/synapse.util.rst +++ /dev/null
@@ -1,21 +0,0 @@ -synapse.util package -==================== - -Submodules ----------- - -.. toctree:: - - synapse.util.async - synapse.util.http - synapse.util.lockutils - synapse.util.logutils - synapse.util.stringutils - -Module contents ---------------- - -.. automodule:: synapse.util - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sphinx/synapse.util.stringutils.rst b/docs/sphinx/synapse.util.stringutils.rst deleted file mode 100644
index ec626eee28..0000000000 --- a/docs/sphinx/synapse.util.stringutils.rst +++ /dev/null
@@ -1,7 +0,0 @@ -synapse.util.stringutils module -=============================== - -.. automodule:: synapse.util.stringutils - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md
index 32b06aa2c5..e1d6ede7ba 100644 --- a/docs/sso_mapping_providers.md +++ b/docs/sso_mapping_providers.md
@@ -15,8 +15,21 @@ where SAML mapping providers come into play. SSO mapping providers are currently supported for OpenID and SAML SSO configurations. Please see the details below for how to implement your own. +It is up to the mapping provider whether the user should be assigned a predefined +Matrix ID based on the SSO attributes, or if the user should be allowed to +choose their own username. + +In the first case - where users are automatically allocated a Matrix ID - it is +the responsibility of the mapping provider to normalise the SSO attributes and +map them to a valid Matrix ID. The [specification for Matrix +IDs](https://matrix.org/docs/spec/appendices#user-identifiers) has some +information about what is considered valid. + +If the mapping provider does not assign a Matrix ID, then Synapse will +automatically serve an HTML page allowing the user to pick their own username. + External mapping providers are provided to Synapse in the form of an external -Python module. You can retrieve this module from [PyPi](https://pypi.org) or elsewhere, +Python module. You can retrieve this module from [PyPI](https://pypi.org) or elsewhere, but it must be importable via Synapse (e.g. it must be in the same virtualenv as Synapse). The Synapse config is then modified to point to the mapping provider (and optionally provide additional configuration for it). @@ -56,16 +69,26 @@ A custom mapping provider must specify the following methods: information from. - This method must return a string, which is the unique identifier for the user. Commonly the ``sub`` claim of the response. -* `map_user_attributes(self, userinfo, token)` +* `map_user_attributes(self, userinfo, token, failures)` - This method must be async. - Arguments: - `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user information from. - `token` - A dictionary which includes information necessary to make further requests to the OpenID provider. + - `failures` - An `int` that represents the amount of times the returned + mxid localpart mapping has failed. This should be used + to create a deduplicated mxid localpart which should be + returned instead. For example, if this method returns + `john.doe` as the value of `localpart` in the returned + dict, and that is already taken on the homeserver, this + method will be called again with the same parameters but + with failures=1. The method should then return a different + `localpart` value, such as `john.doe1`. - Returns a dictionary with two keys: - - localpart: A required string, used to generate the Matrix ID. - - displayname: An optional string, the display name for the user. + - `localpart`: A string, used to generate the Matrix ID. If this is + `None`, the user is prompted to pick their own username. + - `displayname`: An optional string, the display name for the user. * `get_extra_attributes(self, userinfo, token)` - This method must be async. - Arguments: @@ -100,11 +123,13 @@ comment these options out and use those specified by the module instead. A custom mapping provider must specify the following methods: -* `__init__(self, parsed_config)` +* `__init__(self, parsed_config, module_api)` - Arguments: - `parsed_config` - A configuration object that is the return value of the `parse_config` method. You should set any configuration options needed by the module here. + - `module_api` - a `synapse.module_api.ModuleApi` object which provides the + stable API available for extension modules. 
* `parse_config(config)` - This method should have the `@staticmethod` decoration. - Arguments: @@ -147,12 +172,20 @@ A custom mapping provider must specify the following methods: redirected to. - This method must return a dictionary, which will then be used by Synapse to build a new user. The following keys are allowed: - * `mxid_localpart` - Required. The mxid localpart of the new user. + * `mxid_localpart` - The mxid localpart of the new user. If this is + `None`, the user is prompted to pick their own username. * `displayname` - The displayname of the new user. If not provided, will default to the value of `mxid_localpart`. * `emails` - A list of emails for the new user. If not provided, will default to an empty list. + Alternatively it can raise a `synapse.api.errors.RedirectException` to + redirect the user to another page. This is useful to prompt the user for + additional information, e.g. if you want them to provide their own username. + It is the responsibility of the mapping provider to either redirect back + to `client_redirect_url` (including any additional information) or to + complete registration using methods from the `ModuleApi`. + ### Default SAML Mapping Provider Synapse has a built-in SAML mapping provider if a custom provider isn't diff --git a/docs/structured_logging.md b/docs/structured_logging.md
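As an illustration of the OpenID mapping provider interface described above, here is a minimal sketch of a provider that deduplicates the localpart using the `failures` counter. Only the method names and signatures come from the documentation in this changeset; the class name, the `preferred_username` claim and the normalisation rules are illustrative assumptions.

```python
import string


class ExampleOidcMappingProvider:
    def __init__(self, parsed_config):
        self._config = parsed_config

    @staticmethod
    def parse_config(config):
        # This sketch needs no options; return the raw config unchanged.
        return config

    def get_remote_user_id(self, userinfo):
        # The `sub` claim is commonly the unique identifier for the user.
        return userinfo["sub"]

    async def map_user_attributes(self, userinfo, token, failures):
        # Normalise the preferred username into a valid Matrix localpart.
        allowed = string.ascii_lowercase + string.digits + "._=-/"
        raw = userinfo.get("preferred_username", "").lower()
        base = "".join(c for c in raw if c in allowed)
        # On a collision Synapse calls this method again with an incremented
        # `failures` count, so append it to produce e.g. "john.doe1".
        localpart = f"{base}{failures}" if failures else base
        return {
            # Returning None for localpart prompts the user to pick a username.
            "localpart": localpart or None,
            "displayname": userinfo.get("name"),
        }

    async def get_extra_attributes(self, userinfo, token):
        # No extra attributes in this sketch.
        return {}
```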
index decec9b8fa..b1281667e0 100644 --- a/docs/structured_logging.md +++ b/docs/structured_logging.md
@@ -1,83 +1,161 @@ # Structured Logging -A structured logging system can be useful when your logs are destined for a machine to parse and process. By maintaining its machine-readable characteristics, it enables more efficient searching and aggregations when consumed by software such as the "ELK stack". +A structured logging system can be useful when your logs are destined for a +machine to parse and process. By maintaining its machine-readable characteristics, +it enables more efficient searching and aggregations when consumed by software +such as the "ELK stack". -Synapse's structured logging system is configured via the file that Synapse's `log_config` config option points to. The file must be YAML and contain `structured: true`. It must contain a list of "drains" (places where logs go to). +Synapse's structured logging system is configured via the file that Synapse's +`log_config` config option points to. The file should include a formatter which +uses the `synapse.logging.TerseJsonFormatter` class included with Synapse and a +handler which uses the above formatter. + +There is also a `synapse.logging.JsonFormatter` option which does not include +a timestamp in the resulting JSON. This is useful if the log ingester adds its +own timestamp. A structured logging configuration looks similar to the following: ```yaml -structured: true +version: 1 + +formatters: + structured: + class: synapse.logging.TerseJsonFormatter + +handlers: + file: + class: logging.handlers.TimedRotatingFileHandler + formatter: structured + filename: /path/to/my/logs/homeserver.log + when: midnight + backupCount: 3 # Does not include the current log file. + encoding: utf8 loggers: synapse: level: INFO + handlers: [remote] synapse.storage.SQL: level: WARNING - -drains: - console: - type: console - location: stdout - file: - type: file_json - location: homeserver.log ``` -The above logging config will set Synapse as 'INFO' logging level by default, with the SQL layer at 'WARNING', and will have two logging drains (to the console and to a file, stored as JSON). - -## Drain Types +The above logging config will set Synapse as 'INFO' logging level by default, +with the SQL layer at 'WARNING', and will log to a file, stored as JSON. -Drain types can be specified by the `type` key. +It is also possible to figure Synapse to log to a remote endpoint by using the +`synapse.logging.RemoteHandler` class included with Synapse. It takes the +following arguments: -### `console` +- `host`: Hostname or IP address of the log aggregator. +- `port`: Numerical port to contact on the host. +- `maximum_buffer`: (Optional, defaults to 1000) The maximum buffer size to allow. -Outputs human-readable logs to the console. +A remote structured logging configuration looks similar to the following: -Arguments: +```yaml +version: 1 -- `location`: Either `stdout` or `stderr`. +formatters: + structured: + class: synapse.logging.TerseJsonFormatter -### `console_json` +handlers: + remote: + class: synapse.logging.RemoteHandler + formatter: structured + host: 10.1.2.3 + port: 9999 -Outputs machine-readable JSON logs to the console. +loggers: + synapse: + level: INFO + handlers: [remote] + synapse.storage.SQL: + level: WARNING +``` -Arguments: +The above logging config will set Synapse as 'INFO' logging level by default, +with the SQL layer at 'WARNING', and will log JSON formatted messages to a +remote endpoint at 10.1.2.3:9999. -- `location`: Either `stdout` or `stderr`. 
+## Upgrading from legacy structured logging configuration -### `console_json_terse` +Versions of Synapse prior to v1.23.0 included a custom structured logging +configuration which is deprecated. It used a `structured: true` flag and +configured `drains` instead of ``handlers`` and `formatters`. -Outputs machine-readable JSON logs to the console, separated by newlines. This -format is not designed to be read and re-formatted into human-readable text, but -is optimal for a logging aggregation system. +Synapse currently automatically converts the old configuration to the new +configuration, but this will be removed in a future version of Synapse. The +following reference can be used to update your configuration. Based on the drain +`type`, we can pick a new handler: -Arguments: +1. For a type of `console`, `console_json`, or `console_json_terse`: a handler + with a class of `logging.StreamHandler` and a `stream` of `ext://sys.stdout` + or `ext://sys.stderr` should be used. +2. For a type of `file` or `file_json`: a handler of `logging.FileHandler` with + a location of the file path should be used. +3. For a type of `network_json_terse`: a handler of `synapse.logging.RemoteHandler` + with the host and port should be used. -- `location`: Either `stdout` or `stderr`. +Then based on the drain `type` we can pick a new formatter: -### `file` +1. For a type of `console` or `file` no formatter is necessary. +2. For a type of `console_json` or `file_json`: a formatter of + `synapse.logging.JsonFormatter` should be used. +3. For a type of `console_json_terse` or `network_json_terse`: a formatter of + `synapse.logging.TerseJsonFormatter` should be used. -Outputs human-readable logs to a file. +For each new handler and formatter they should be added to the logging configuration +and then assigned to either a logger or the root logger. -Arguments: +An example legacy configuration: -- `location`: An absolute path to the file to log to. +```yaml +structured: true -### `file_json` +loggers: + synapse: + level: INFO + synapse.storage.SQL: + level: WARNING -Outputs machine-readable logs to a file. +drains: + console: + type: console + location: stdout + file: + type: file_json + location: homeserver.log +``` -Arguments: +Would be converted into a new configuration: -- `location`: An absolute path to the file to log to. +```yaml +version: 1 -### `network_json_terse` +formatters: + json: + class: synapse.logging.JsonFormatter -Delivers machine-readable JSON logs to a log aggregator over TCP. This is -compatible with LogStash's TCP input with the codec set to `json_lines`. +handlers: + console: + class: logging.StreamHandler + location: ext://sys.stdout + file: + class: logging.FileHandler + formatter: json + filename: homeserver.log -Arguments: +loggers: + synapse: + level: INFO + handlers: [console, file] + synapse.storage.SQL: + level: WARNING +``` -- `host`: Hostname or IP address of the log aggregator. -- `port`: Numerical port to contact on the host. \ No newline at end of file +The new logging configuration is a bit more verbose, but significantly more +flexible. It allows for configuration that were not previously possible, such as +sending plain logs over the network, or using different handlers for different +modules. diff --git a/docs/systemd-with-workers/README.md b/docs/systemd-with-workers/README.md
index 257c09446f..8e57d4f62e 100644 --- a/docs/systemd-with-workers/README.md +++ b/docs/systemd-with-workers/README.md
@@ -37,10 +37,10 @@ synapse master process to be started as part of the `matrix-synapse.target` target. 1. For each worker process to be enabled, run `systemctl enable matrix-synapse-worker@<worker_name>.service`. For each `<worker_name>`, there -should be a corresponding configuration file +should be a corresponding configuration file. `/etc/matrix-synapse/workers/<worker_name>.yaml`. 1. Start all the synapse processes with `systemctl start matrix-synapse.target`. -1. Tell systemd to start synapse on boot with `systemctl enable matrix-synapse.target`/ +1. Tell systemd to start synapse on boot with `systemctl enable matrix-synapse.target`. ## Usage diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md
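Taken together, the systemd steps above amount to something like the following (the worker name `federation_reader` is only an example):

```sh
# Enable one unit per worker config file in /etc/matrix-synapse/workers/.
systemctl enable matrix-synapse-worker@federation_reader.service

# Start all of the Synapse processes now, and again on every boot.
systemctl start matrix-synapse.target
systemctl enable matrix-synapse.target
```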
index db318baa9d..ad145439b4 100644 --- a/docs/tcp_replication.md +++ b/docs/tcp_replication.md
@@ -15,7 +15,7 @@ example flow would be (where '>' indicates master to worker and > SERVER example.com < REPLICATE - > POSITION events master 53 + > POSITION events master 53 53 > RDATA events master 54 ["$foo1:bar.com", ...] > RDATA events master 55 ["$foo4:bar.com", ...] @@ -138,9 +138,9 @@ the wire: < NAME synapse.app.appservice < PING 1490197665618 < REPLICATE - > POSITION events master 1 - > POSITION backfill master 1 - > POSITION caches master 1 + > POSITION events master 1 1 + > POSITION backfill master 1 1 + > POSITION caches master 1 1 > RDATA caches master 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513] > RDATA events master 14 ["$149019767112vOHxz:localhost:8823", "!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null] @@ -185,6 +185,11 @@ client (C): updates via HTTP API, rather than via the DB, then processes should make the request to the appropriate process. + Two positions are included, the "new" position and the last position sent respectively. + This allows servers to tell instances that the positions have advanced but no + data has been written, without clients needlessly checking to see if they + have missed any updates. + #### ERROR (S, C) There was an error diff --git a/docs/turn-howto.md b/docs/turn-howto.md
index d4a726be66..a470c274a5 100644 --- a/docs/turn-howto.md +++ b/docs/turn-howto.md
@@ -42,10 +42,10 @@ This will install and start a systemd service called `coturn`. ./configure - > You may need to install `libevent2`: if so, you should do so in - > the way recommended by your operating system. You can ignore - > warnings about lack of database support: a database is unnecessary - > for this purpose. + You may need to install `libevent2`: if so, you should do so in + the way recommended by your operating system. You can ignore + warnings about lack of database support: a database is unnecessary + for this purpose. 1. Build and install it: @@ -66,6 +66,19 @@ This will install and start a systemd service called `coturn`. pwgen -s 64 1 + A `realm` must be specified, but its value is somewhat arbitrary. (It is + sent to clients as part of the authentication flow.) It is conventional to + set it to be your server name. + +1. You will most likely want to configure coturn to write logs somewhere. The + easiest way is normally to send them to the syslog: + + syslog + + (in which case, the logs will be available via `journalctl -u coturn` on a + systemd system). Alternatively, coturn can be configured to write to a + logfile - check the example config file supplied with coturn. + 1. Consider your security settings. TURN lets users request a relay which will connect to arbitrary IP addresses and ports. The following configuration is suggested as a minimum starting point: @@ -96,11 +109,31 @@ This will install and start a systemd service called `coturn`. # TLS private key file pkey=/path/to/privkey.pem + In this case, replace the `turn:` schemes in the `turn_uri` settings below + with `turns:`. + + We recommend that you only try to set up TLS/DTLS once you have set up a + basic installation and got it working. + 1. Ensure your firewall allows traffic into the TURN server on the ports - you've configured it to listen on (By default: 3478 and 5349 for the TURN(s) + you've configured it to listen on (By default: 3478 and 5349 for TURN traffic (remember to allow both TCP and UDP traffic), and ports 49152-65535 for the UDP relay.) +1. We do not recommend running a TURN server behind NAT, and are not aware of + anyone doing so successfully. + + If you want to try it anyway, you will at least need to tell coturn its + external IP address: + + external-ip=192.88.99.1 + + ... and your NAT gateway must forward all of the relayed ports directly + (eg, port 56789 on the external IP must be always be forwarded to port + 56789 on the internal IP). + + If you get this working, let us know! + 1. (Re)start the turn server: * If you used the Debian package (or have set up a systemd unit yourself): @@ -137,9 +170,10 @@ Your home server configuration file needs the following extra keys: without having gone through a CAPTCHA or similar to register a real account. -As an example, here is the relevant section of the config file for matrix.org: +As an example, here is the relevant section of the config file for `matrix.org`. The +`turn_uris` are appropriate for TURN servers listening on the default ports, with no TLS. - turn_uris: [ "turn:turn.matrix.org:3478?transport=udp", "turn:turn.matrix.org:3478?transport=tcp" ] + turn_uris: [ "turn:turn.matrix.org?transport=udp", "turn:turn.matrix.org?transport=tcp" ] turn_shared_secret: "n0t4ctuAllymatr1Xd0TorgSshar3d5ecret4obvIousreAsons" turn_user_lifetime: 86400000 turn_allow_guests: True @@ -155,5 +189,86 @@ After updating the homeserver configuration, you must restart synapse: ``` systemctl restart synapse.service ``` +... 
and then reload any clients (or wait an hour for them to refresh their +settings). + +## Troubleshooting + +The normal symptoms of a misconfigured TURN server are that calls between +devices on different networks ring, but get stuck at "call +connecting". Unfortunately, troubleshooting this can be tricky. + +Here are a few things to try: + + * Check that your TURN server is not behind NAT. As above, we're not aware of + anyone who has successfully set this up. + + * Check that you have opened your firewall to allow TCP and UDP traffic to the + TURN ports (normally 3478 and 5479). + + * Check that you have opened your firewall to allow UDP traffic to the UDP + relay ports (49152-65535 by default). + + * Some WebRTC implementations (notably, that of Google Chrome) appear to get + confused by TURN servers which are reachable over IPv6 (this appears to be + an unexpected side-effect of its handling of multiple IP addresses as + defined by + [`draft-ietf-rtcweb-ip-handling`](https://tools.ietf.org/html/draft-ietf-rtcweb-ip-handling-12)). + + Try removing any AAAA records for your TURN server, so that it is only + reachable over IPv4. + + * Enable more verbose logging in coturn via the `verbose` setting: + + ``` + verbose + ``` + + ... and then see if there are any clues in its logs. + + * If you are using a browser-based client under Chrome, check + `chrome://webrtc-internals/` for insights into the internals of the + negotiation. On Firefox, check the "Connection Log" on `about:webrtc`. + + (Understanding the output is beyond the scope of this document!) + + * There is a WebRTC test tool at + https://webrtc.github.io/samples/src/content/peerconnection/trickle-ice/. To + use it, you will need a username/password for your TURN server. You can + either: + + * look for the `GET /_matrix/client/r0/voip/turnServer` request made by a + matrix client to your homeserver in your browser's network inspector. In + the response you should see `username` and `password`. Or: + + * Use the following shell commands: + + ```sh + secret=staticAuthSecretHere + + u=$((`date +%s` + 3600)):test + p=$(echo -n $u | openssl dgst -hmac $secret -sha1 -binary | base64) + echo -e "username: $u\npassword: $p" + ``` + + Or: + + * Temporarily configure coturn to accept a static username/password. To do + this, comment out `use-auth-secret` and `static-auth-secret` and add the + following: + + ``` + lt-cred-mech + user=username:password + ``` + + **Note**: these settings will not take effect unless `use-auth-secret` + and `static-auth-secret` are disabled. + + Restart coturn after changing the configuration file. + + Remember to restore the original settings to go back to testing with + Matrix clients! -..and your Home Server now supports VoIP relaying! + If the TURN server is working correctly, you should see at least one `relay` + entry in the results. diff --git a/docs/workers.md b/docs/workers.md
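For reference, the credential calculation from the shell snippet above can also be done in Python; this is a sketch of the same HMAC-SHA1 scheme, with the shared secret and username as placeholders:

```python
import base64
import hashlib
import hmac
import time


def turn_credentials(shared_secret: str, user: str = "test", ttl: int = 3600):
    """Compute ephemeral TURN credentials, mirroring the openssl one-liner above."""
    # The username is "<expiry unix timestamp>:<user>".
    username = f"{int(time.time()) + ttl}:{user}"
    # The password is the base64-encoded HMAC-SHA1 of the username,
    # keyed with coturn's static-auth-secret.
    digest = hmac.new(shared_secret.encode(), username.encode(), hashlib.sha1).digest()
    return username, base64.b64encode(digest).decode()


print(turn_credentials("staticAuthSecretHere"))
```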
index ad4d8ca9f2..298adf8695 100644 --- a/docs/workers.md +++ b/docs/workers.md
@@ -89,7 +89,8 @@ shared configuration file. Normally, only a couple of changes are needed to make an existing configuration file suitable for use with workers. First, you need to enable an "HTTP replication listener" for the main process; and secondly, you need to enable redis-based -replication. For example: +replication. Optionally, a shared secret can be used to authenticate HTTP +traffic between workers. For example: ```yaml @@ -103,6 +104,9 @@ listeners: resources: - names: [replication] +# Add a random shared secret to authenticate traffic. +worker_replication_secret: "" + redis: enabled: true ``` @@ -116,7 +120,7 @@ public internet; it has no authentication and is unencrypted. ### Worker configuration In the config file for each worker, you must specify the type of worker -application (`worker_app`), and you should specify a unqiue name for the worker +application (`worker_app`), and you should specify a unique name for the worker (`worker_name`). The currently available worker applications are listed below. You must also specify the HTTP replication endpoint that it should talk to on the main synapse process. `worker_replication_host` should specify the host of @@ -225,6 +229,7 @@ expressions: ^/_matrix/client/(r0|unstable)/auth/.*/fallback/web$ # Event sending requests + ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/redact ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state/ ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$ @@ -262,6 +267,9 @@ using): Note that a HTTP listener with `client` and `federation` resources must be configured in the `worker_listeners` option in the worker config. +Ensure that all SSO logins go to a single process (usually the main process). +For multiple workers not handling the SSO endpoints properly, see +[#7530](https://github.com/matrix-org/synapse/issues/7530). #### Load balancing @@ -302,7 +310,7 @@ Additionally, there is *experimental* support for moving writing of specific streams (such as events) off of the main process to a particular worker. (This is only supported with Redis-based replication.) -Currently support streams are `events` and `typing`. +Currently supported streams are `events` and `typing`. To enable this, the worker must have a HTTP replication listener configured, have a `worker_name` and be listed in the `instance_map` config. For example to @@ -319,6 +327,35 @@ stream_writers: events: event_persister1 ``` +The `events` stream also experimentally supports having multiple writers, where +work is sharded between them by room ID. Note that you *must* restart all worker +instances when adding or removing event persisters. An example `stream_writers` +configuration with multiple writers: + +```yaml +stream_writers: + events: + - event_persister1 + - event_persister2 +``` + +#### Background tasks + +There is also *experimental* support for moving background tasks to a separate +worker. Background tasks are run periodically or started via replication. Exactly +which tasks are configured to run depends on your Synapse configuration (e.g. if +stats is enabled). + +To enable this, the worker must have a `worker_name` and can be configured to run +background tasks. For example, to move background tasks to a dedicated worker, +the shared configuration would include: + +```yaml +run_background_tasks_on: background_worker +``` + +You might also wish to investigate the `update_user_directory` and +`media_instance_running_background_jobs` settings. 
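To make the worker options above concrete, a single stream-writer worker's configuration file might look roughly like the following sketch; the ports, file paths and worker name are placeholders rather than values taken from this changeset:

```yaml
# Hypothetical /etc/matrix-synapse/workers/event_persister1.yaml
worker_app: synapse.app.generic_worker
worker_name: event_persister1

# The main process's HTTP replication listener.
worker_replication_host: 127.0.0.1
worker_replication_http_port: 9093

worker_listeners:
  # A replication listener of its own, as required for stream writers.
  - type: http
    port: 8034
    resources:
      - names: [replication]

worker_log_config: /etc/matrix-synapse/event_persister1.log.yaml
```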
### `synapse.app.pusher` @@ -391,6 +428,8 @@ and you must configure a single instance to run the background tasks, e.g.: media_instance_running_background_jobs: "media-repository-1" ``` +Note that if a reverse proxy is used, then `/_matrix/media/` must be routed for both inbound client and federation requests (if they are handled separately). + ### `synapse.app.user_dir` Handles searches in the user directory. It can handle REST endpoints matching diff --git a/mypy.ini b/mypy.ini
index 7986781432..5d15b7bf1c 100644 --- a/mypy.ini +++ b/mypy.ini
@@ -6,16 +6,33 @@ check_untyped_defs = True show_error_codes = True show_traceback = True mypy_path = stubs +warn_unreachable = True + +# To find all folders that pass mypy you run: +# +# find synapse/* -type d -not -name __pycache__ -exec bash -c "mypy '{}' > /dev/null" \; -print + files = + scripts-dev/sign_json, synapse/api, synapse/appservice, synapse/config, + synapse/crypto, synapse/event_auth.py, synapse/events/builder.py, + synapse/events/validator.py, synapse/events/spamcheck.py, synapse/federation, + synapse/handlers/_base.py, + synapse/handlers/account_data.py, + synapse/handlers/account_validity.py, + synapse/handlers/admin.py, + synapse/handlers/appservice.py, synapse/handlers/auth.py, synapse/handlers/cas_handler.py, + synapse/handlers/deactivate_account.py, + synapse/handlers/device.py, + synapse/handlers/devicemessage.py, synapse/handlers/directory.py, synapse/handlers/events.py, synapse/handlers/federation.py, @@ -24,44 +41,68 @@ files = synapse/handlers/message.py, synapse/handlers/oidc_handler.py, synapse/handlers/pagination.py, + synapse/handlers/password_policy.py, synapse/handlers/presence.py, + synapse/handlers/profile.py, + synapse/handlers/read_marker.py, + synapse/handlers/receipts.py, + synapse/handlers/register.py, synapse/handlers/room.py, + synapse/handlers/room_list.py, synapse/handlers/room_member.py, synapse/handlers/room_member_worker.py, synapse/handlers/saml_handler.py, + synapse/handlers/sso.py, synapse/handlers/sync.py, + synapse/handlers/user_directory.py, synapse/handlers/ui_auth, + synapse/http/client.py, + synapse/http/federation/matrix_federation_agent.py, synapse/http/federation/well_known_resolver.py, + synapse/http/matrixfederationclient.py, synapse/http/server.py, synapse/http/site.py, synapse/logging, synapse/metrics, synapse/module_api, synapse/notifier.py, - synapse/push/pusherpool.py, - synapse/push/push_rule_evaluator.py, + synapse/push, synapse/replication, synapse/rest, synapse/server.py, synapse/server_notices, synapse/spam_checker_api, synapse/state, + synapse/storage/__init__.py, + synapse/storage/_base.py, + synapse/storage/background_updates.py, + synapse/storage/databases/main/appservice.py, synapse/storage/databases/main/events.py, + synapse/storage/databases/main/keys.py, + synapse/storage/databases/main/pusher.py, + synapse/storage/databases/main/registration.py, synapse/storage/databases/main/stream.py, synapse/storage/databases/main/ui_auth.py, synapse/storage/database.py, synapse/storage/engines, + synapse/storage/keys.py, synapse/storage/persist_events.py, + synapse/storage/prepare_database.py, + synapse/storage/purge_events.py, + synapse/storage/push_rule.py, + synapse/storage/relations.py, + synapse/storage/roommember.py, synapse/storage/state.py, + synapse/storage/types.py, synapse/storage/util, synapse/streams, synapse/types.py, synapse/util/async_helpers.py, - synapse/util/caches/descriptors.py, - synapse/util/caches/stream_change_cache.py, + synapse/util/caches, synapse/util/metrics.py, tests/replication, tests/test_utils, + tests/handlers/test_password_providers.py, tests/rest/client/v2_alpha/test_auth.py, tests/util/test_stream_change_cache.py @@ -71,6 +112,9 @@ ignore_missing_imports = True [mypy-zope] ignore_missing_imports = True +[mypy-bcrypt] +ignore_missing_imports = True + [mypy-constantly] ignore_missing_imports = True @@ -86,10 +130,13 @@ ignore_missing_imports = True [mypy-h11] ignore_missing_imports = True +[mypy-msgpack] +ignore_missing_imports = True + [mypy-opentracing] ignore_missing_imports = True 
-[mypy-OpenSSL] +[mypy-OpenSSL.*] ignore_missing_imports = True [mypy-netaddr] @@ -142,3 +189,6 @@ ignore_missing_imports = True [mypy-nacl.*] ignore_missing_imports = True + +[mypy-hiredis] +ignore_missing_imports = True diff --git a/pyproject.toml b/pyproject.toml
index db4a2e41e4..cd880d4e39 100644 --- a/pyproject.toml +++ b/pyproject.toml
@@ -35,7 +35,7 @@ showcontent = true [tool.black] -target-version = ['py34'] +target-version = ['py35'] exclude = ''' ( diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh
index 0647993658..f328ab57d5 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh
@@ -1,4 +1,4 @@ -#!/bin/sh +#!/bin/bash # # Runs linting scripts over the local Synapse checkout # isort - sorts import statements @@ -7,15 +7,91 @@ set -e -if [ $# -ge 1 ] -then - files=$* +usage() { + echo + echo "Usage: $0 [-h] [-d] [paths...]" + echo + echo "-d" + echo " Lint files that have changed since the last git commit." + echo + echo " If paths are provided and this option is set, both provided paths and those" + echo " that have changed since the last commit will be linted." + echo + echo " If no paths are provided and this option is not set, all files will be linted." + echo + echo " Note that paths with a file extension that is not '.py' will be excluded." + echo "-h" + echo " Display this help text." +} + +USING_DIFF=0 +files=() + +while getopts ":dh" opt; do + case $opt in + d) + USING_DIFF=1 + ;; + h) + usage + exit + ;; + \?) + echo "ERROR: Invalid option: -$OPTARG" >&2 + usage + exit + ;; + esac +done + +# Strip any options from the command line arguments now that +# we've finished processing them +shift "$((OPTIND-1))" + +if [ $USING_DIFF -eq 1 ]; then + # Check both staged and non-staged changes + for path in $(git diff HEAD --name-only); do + filename=$(basename "$path") + file_extension="${filename##*.}" + + # If an extension is present, and it's something other than 'py', + # then ignore this file + if [[ -n ${file_extension+x} && $file_extension != "py" ]]; then + continue + fi + + # Append this path to our list of files to lint + files+=("$path") + done +fi + +# Append any remaining arguments as files to lint +files+=("$@") + +if [[ $USING_DIFF -eq 1 ]]; then + # If we were asked to lint changed files, and no paths were found as a result... + if [ ${#files[@]} -eq 0 ]; then + # Then print and exit + echo "No files found to lint." + exit 0 + fi else - files="synapse tests scripts-dev scripts contrib synctl" + # If we were not asked to lint changed files, and no paths were found as a result, + # then lint everything! + if [[ -z ${files+x} ]]; then + # Lint all source code files and directories + files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark") + fi fi -echo "Linting these locations: $files" -isort $files -python3 -m black $files +echo "Linting these paths: ${files[*]}" +echo + +# Print out the commands being run +set -x + +isort "${files[@]}" +python3 -m black "${files[@]}" ./scripts-dev/config-lint.sh -flake8 $files +flake8 "${files[@]}" +mypy diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py
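Typical invocations of the reworked lint script might look like this (the specific paths are only examples):

```sh
# Lint everything that has changed since the last commit:
./scripts-dev/lint.sh -d

# Lint an explicit set of paths instead:
./scripts-dev/lint.sh synapse/handlers tests/handlers/test_password_providers.py
```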
index a5b88731f1..f7f18805e4 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py
@@ -19,9 +19,10 @@ can crop up, e.g the cache descriptors. from typing import Callable, Optional +from mypy.nodes import ARG_NAMED_OPT from mypy.plugin import MethodSigContext, Plugin from mypy.typeops import bind_self -from mypy.types import CallableType +from mypy.types import CallableType, NoneType class SynapsePlugin(Plugin): @@ -30,6 +31,8 @@ class SynapsePlugin(Plugin): ) -> Optional[Callable[[MethodSigContext], CallableType]]: if fullname.startswith( "synapse.util.caches.descriptors._CachedFunction.__call__" + ) or fullname.startswith( + "synapse.util.caches.descriptors._LruCachedFunction.__call__" ): return cached_function_method_signature return None @@ -40,8 +43,9 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: It already has *almost* the correct signature, except: - 1. the `self` argument needs to be marked as "bound"; and - 2. any `cache_context` argument should be removed. + 1. the `self` argument needs to be marked as "bound"; + 2. any `cache_context` argument should be removed; + 3. an optional keyword argument `on_invalidated` should be added. """ # First we mark this as a bound function signature. @@ -58,19 +62,33 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: context_arg_index = idx break + arg_types = list(signature.arg_types) + arg_names = list(signature.arg_names) + arg_kinds = list(signature.arg_kinds) + if context_arg_index: - arg_types = list(signature.arg_types) arg_types.pop(context_arg_index) - - arg_names = list(signature.arg_names) arg_names.pop(context_arg_index) - - arg_kinds = list(signature.arg_kinds) arg_kinds.pop(context_arg_index) - signature = signature.copy_modified( - arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds, - ) + # Third, we add an optional "on_invalidate" argument. + # + # This is a callable which accepts no input and returns nothing. + calltyp = CallableType( + arg_types=[], + arg_kinds=[], + arg_names=[], + ret_type=NoneType(), + fallback=ctx.api.named_generic_type("builtins.function", []), + ) + + arg_types.append(calltyp) + arg_names.append("on_invalidate") + arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg. + + signature = signature.copy_modified( + arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds, + ) return signature diff --git a/scripts-dev/sign_json b/scripts-dev/sign_json new file mode 100755
index 0000000000..44553fb79a --- /dev/null +++ b/scripts-dev/sign_json
@@ -0,0 +1,127 @@ +#!/usr/bin/env python +# +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import json +import sys +from json import JSONDecodeError + +import yaml +from signedjson.key import read_signing_keys +from signedjson.sign import sign_json + +from synapse.util import json_encoder + + +def main(): + parser = argparse.ArgumentParser( + description="""Adds a signature to a JSON object. + +Example usage: + + $ scripts-dev/sign_json.py -N test -k localhost.signing.key "{}" + {"signatures":{"test":{"ed25519:a_ZnZh":"LmPnml6iM0iR..."}}} +""", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + "-N", + "--server-name", + help="Name to give as the local homeserver. If unspecified, will be " + "read from the config file.", + ) + + parser.add_argument( + "-k", + "--signing-key-path", + help="Path to the file containing the private ed25519 key to sign the " + "request with.", + ) + + parser.add_argument( + "-c", + "--config", + default="homeserver.yaml", + help=( + "Path to synapse config file, from which the server name and/or signing " + "key path will be read. Ignored if --server-name and --signing-key-path " + "are both given." + ), + ) + + input_args = parser.add_mutually_exclusive_group() + + input_args.add_argument("input_data", nargs="?", help="Raw JSON to be signed.") + + input_args.add_argument( + "-i", + "--input", + type=argparse.FileType("r"), + default=sys.stdin, + help=( + "A file from which to read the JSON to be signed. If neither --input nor " + "input_data are given, JSON will be read from stdin." + ), + ) + + parser.add_argument( + "-o", + "--output", + type=argparse.FileType("w"), + default=sys.stdout, + help="Where to write the signed JSON. Defaults to stdout.", + ) + + args = parser.parse_args() + + if not args.server_name or not args.signing_key_path: + read_args_from_config(args) + + with open(args.signing_key_path) as f: + key = read_signing_keys(f)[0] + + json_to_sign = args.input_data + if json_to_sign is None: + json_to_sign = args.input.read() + + try: + obj = json.loads(json_to_sign) + except JSONDecodeError as e: + print("Unable to parse input as JSON: %s" % e, file=sys.stderr) + sys.exit(1) + + if not isinstance(obj, dict): + print("Input json was not an object", file=sys.stderr) + sys.exit(1) + + sign_json(obj, args.server_name, key) + for c in json_encoder.iterencode(obj): + args.output.write(c) + args.output.write("\n") + + +def read_args_from_config(args: argparse.Namespace) -> None: + with open(args.config, "r") as fh: + config = yaml.safe_load(fh) + if not args.server_name: + args.server_name = config["server_name"] + if not args.signing_key_path: + args.signing_key_path = config["signing_key_path"] + + +if __name__ == "__main__": + main() diff --git a/scripts-dev/sphinx_api_docs.sh b/scripts-dev/sphinx_api_docs.sh deleted file mode 100644
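Editor's note: the core of the new sign_json helper is the signedjson library. A minimal stand-alone sketch of the same signing flow, using a throwaway key rather than the homeserver's configured signing key:

from signedjson.key import generate_signing_key
from signedjson.sign import sign_json

signing_key = generate_signing_key("a_key_id")        # ad-hoc key, for illustration only
signed = sign_json({"hello": "world"}, "example.com", signing_key)
print(signed["signatures"]["example.com"])            # {'ed25519:a_key_id': '...'}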
index ee72b29657..0000000000 --- a/scripts-dev/sphinx_api_docs.sh +++ /dev/null
@@ -1 +0,0 @@ -sphinx-apidoc -o docs/sphinx/ synapse/ -ef diff --git a/scripts/generate_log_config b/scripts/generate_log_config
index b6957f48a3..a13a5634a3 100755 --- a/scripts/generate_log_config +++ b/scripts/generate_log_config
@@ -40,4 +40,6 @@ if __name__ == "__main__": ) args = parser.parse_args() - args.output_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=args.log_file)) + out = args.output_file + out.write(DEFAULT_LOG_CONFIG.substitute(log_file=args.log_file)) + out.flush() diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index ae2887b7d2..5ad17aa90f 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db
@@ -22,6 +22,7 @@ import logging import sys import time import traceback +from typing import Dict, Optional, Set import yaml @@ -39,6 +40,7 @@ from synapse.storage.database import DatabasePool, make_conn from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore +from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyBackgroundStore from synapse.storage.databases.main.events_bg_updates import ( EventsBackgroundUpdatesStore, ) @@ -90,6 +92,7 @@ BOOLEAN_COLUMNS = { "room_stats_state": ["is_federatable"], "local_media_repository": ["safe_from_quarantine"], "users": ["shadow_banned"], + "e2e_fallback_keys_json": ["used"], } @@ -151,7 +154,7 @@ IGNORED_TABLES = { # Error returned by the run function. Used at the top-level part of the script to # handle errors and return codes. -end_error = None +end_error = None # type: Optional[str] # The exec_info for the error, if any. If error is defined but not exec_info the script # will show only the error message without the stacktrace, if exec_info is defined but # not the error then the script will show nothing outside of what's printed in the run @@ -172,6 +175,7 @@ class Store( StateBackgroundUpdateStore, MainStateBackgroundUpdateStore, UserDirectoryBackgroundUpdateStore, + EndToEndKeyBackgroundStore, StatsStore, ): def execute(self, f, *args, **kwargs): @@ -288,6 +292,34 @@ class Porter(object): return table, already_ported, total_to_port, forward_chunk, backward_chunk + async def get_table_constraints(self) -> Dict[str, Set[str]]: + """Returns a map of tables that have foreign key constraints to tables they depend on. + """ + + def _get_constraints(txn): + # We can pull the information about foreign key constraints out from + # the postgres schema tables. + sql = """ + SELECT DISTINCT + tc.table_name, + ccu.table_name AS foreign_table_name + FROM + information_schema.table_constraints AS tc + INNER JOIN information_schema.constraint_column_usage AS ccu + USING (table_schema, constraint_name) + WHERE tc.constraint_type = 'FOREIGN KEY'; + """ + txn.execute(sql) + + results = {} + for table, foreign_table in txn: + results.setdefault(table, set()).add(foreign_table) + return results + + return await self.postgres_store.db_pool.runInteraction( + "get_table_constraints", _get_constraints + ) + async def handle_table( self, table, postgres_size, table_size, forward_chunk, backward_chunk ): @@ -489,7 +521,7 @@ class Porter(object): hs = MockHomeserver(self.hs_config) - with make_conn(db_config, engine) as db_conn: + with make_conn(db_config, engine, "portdb") as db_conn: engine.check_database( db_conn, allow_outdated_version=allow_outdated_version ) @@ -587,7 +619,18 @@ class Porter(object): "create_port_table", create_port_table ) - # Step 2. Get tables. + # Step 2. Set up sequences + # + # We do this before porting the tables so that event if we fail half + # way through the postgres DB always have sequences that are greater + # than their respective tables. If we don't then creating the + # `DataStore` object will fail due to the inconsistency. + self.progress.set_state("Setting up sequence generators") + await self._setup_state_group_id_seq() + await self._setup_user_id_seq() + await self._setup_events_stream_seqs() + + # Step 3. Get tables. 
self.progress.set_state("Fetching tables") sqlite_tables = await self.sqlite_store.db_pool.simple_select_onecol( table="sqlite_master", keyvalues={"type": "table"}, retcol="name" @@ -602,7 +645,7 @@ class Porter(object): tables = set(sqlite_tables) & set(postgres_tables) logger.info("Found %d tables", len(tables)) - # Step 3. Figure out what still needs copying + # Step 4. Figure out what still needs copying self.progress.set_state("Checking on port progress") setup_res = await make_deferred_yieldable( defer.gatherResults( @@ -615,26 +658,48 @@ class Porter(object): consumeErrors=True, ) ) - - # Step 4. Do the copying. + # Map from table name to args passed to `handle_table`, i.e. a tuple + # of: `postgres_size`, `table_size`, `forward_chunk`, `backward_chunk`. + tables_to_port_info_map = {r[0]: r[1:] for r in setup_res} + + # Step 5. Do the copying. + # + # This is slightly convoluted as we need to ensure tables are ported + # in the correct order due to foreign key constraints. self.progress.set_state("Copying to postgres") - await make_deferred_yieldable( - defer.gatherResults( - [run_in_background(self.handle_table, *res) for res in setup_res], - consumeErrors=True, + + constraints = await self.get_table_constraints() + tables_ported = set() # type: Set[str] + + while tables_to_port_info_map: + # Pulls out all tables that are still to be ported and which + # only depend on tables that are already ported (if any). + tables_to_port = [ + table + for table in tables_to_port_info_map + if not constraints.get(table, set()) - tables_ported + ] + + await make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background( + self.handle_table, + table, + *tables_to_port_info_map.pop(table), + ) + for table in tables_to_port + ], + consumeErrors=True, + ) ) - ) - # Step 5. 
Set up sequences - self.progress.set_state("Setting up sequence generators") - await self._setup_state_group_id_seq() - await self._setup_user_id_seq() - await self._setup_events_stream_seqs() + tables_ported.update(tables_to_port) self.progress.done() except Exception as e: global end_error_exec_info - end_error = e + end_error = str(e) end_error_exec_info = sys.exc_info() logger.exception("") finally: @@ -788,45 +853,62 @@ class Porter(object): return done, remaining + done - def _setup_state_group_id_seq(self): + async def _setup_state_group_id_seq(self): + curr_id = await self.sqlite_store.db_pool.simple_select_one_onecol( + table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True + ) + + if not curr_id: + return + def r(txn): - txn.execute("SELECT MAX(id) FROM state_groups") - curr_id = txn.fetchone()[0] - if not curr_id: - return next_id = curr_id + 1 txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,)) - return self.postgres_store.db_pool.runInteraction("setup_state_group_id_seq", r) + await self.postgres_store.db_pool.runInteraction("setup_state_group_id_seq", r) + + async def _setup_user_id_seq(self): + curr_id = await self.sqlite_store.db_pool.runInteraction( + "setup_user_id_seq", find_max_generated_user_id_localpart + ) - def _setup_user_id_seq(self): def r(txn): - next_id = find_max_generated_user_id_localpart(txn) + 1 + next_id = curr_id + 1 txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,)) return self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r) - def _setup_events_stream_seqs(self): - def r(txn): - txn.execute("SELECT MAX(stream_ordering) FROM events") - curr_id = txn.fetchone()[0] - if curr_id: - next_id = curr_id + 1 - txn.execute( - "ALTER SEQUENCE events_stream_seq RESTART WITH %s", (next_id,) - ) + async def _setup_events_stream_seqs(self): + """Set the event stream sequences to the correct values. + """ + + # We get called before we've ported the events table, so we need to + # fetch the current positions from the SQLite store. + curr_forward_id = await self.sqlite_store.db_pool.simple_select_one_onecol( + table="events", keyvalues={}, retcol="MAX(stream_ordering)", allow_none=True + ) - txn.execute("SELECT -MIN(stream_ordering) FROM events") - curr_id = txn.fetchone()[0] - if curr_id: - next_id = curr_id + 1 + curr_backward_id = await self.sqlite_store.db_pool.simple_select_one_onecol( + table="events", + keyvalues={}, + retcol="MAX(-MIN(stream_ordering), 1)", + allow_none=True, + ) + + def _setup_events_stream_seqs_set_pos(txn): + if curr_forward_id: txn.execute( - "ALTER SEQUENCE events_backfill_stream_seq RESTART WITH %s", - (next_id,), + "ALTER SEQUENCE events_stream_seq RESTART WITH %s", + (curr_forward_id + 1,), ) - return self.postgres_store.db_pool.runInteraction( - "_setup_events_stream_seqs", r + txn.execute( + "ALTER SEQUENCE events_backfill_stream_seq RESTART WITH %s", + (curr_backward_id + 1,), + ) + + return await self.postgres_store.db_pool.runInteraction( + "_setup_events_stream_seqs", _setup_events_stream_seqs_set_pos, ) diff --git a/setup.cfg b/setup.cfg
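Editor's note: the porter changes above do two related things: sequences are now set up before any tables are copied, and tables are copied in dependency order using the foreign-key map returned by get_table_constraints(). A hedged sketch of that wave-by-wave ordering (names here are illustrative, not the script's own):

def port_in_waves(tables, constraints, port_one):
    """constraints maps a table to the set of tables it references."""
    remaining = set(tables)
    ported = set()
    while remaining:
        # A table is ready once everything it depends on has been ported.
        wave = [t for t in remaining if not constraints.get(t, set()) - ported]
        if not wave:
            raise RuntimeError("circular foreign-key dependency between tables")
        for table in wave:
            port_one(table)
        ported.update(wave)
        remaining.difference_update(wave)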
index a32278ea8a..f46e43fad0 100644 --- a/setup.cfg +++ b/setup.cfg
@@ -1,8 +1,3 @@ -[build_sphinx] -source-dir = docs/sphinx -build-dir = docs/build -all_files = 1 - [trial] test_suite = tests diff --git a/setup.py b/setup.py
index 926b1bc86f..9730afb41b 100755 --- a/setup.py +++ b/setup.py
@@ -15,12 +15,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import glob import os -from setuptools import setup, find_packages, Command -import sys +from setuptools import Command, find_packages, setup here = os.path.abspath(os.path.dirname(__file__)) @@ -104,6 +102,8 @@ CONDITIONAL_REQUIREMENTS["lint"] = [ "flake8", ] +CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.790", "mypy-zope==0.2.8"] + # Dependencies which are exclusively required by unit test code. This is # NOT a list of all modules that are necessary to run the unit tests. # Tests assume that all optional dependencies are installed. @@ -131,6 +131,7 @@ setup( "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", ], scripts=["synctl"] + glob.glob("scripts/*"), cmdclass={"test": TestCommand}, diff --git a/stubs/sortedcontainers/__init__.pyi b/stubs/sortedcontainers/__init__.pyi
index 073b806d3c..fa307483fe 100644 --- a/stubs/sortedcontainers/__init__.pyi +++ b/stubs/sortedcontainers/__init__.pyi
@@ -1,13 +1,12 @@ -from .sorteddict import ( - SortedDict, - SortedKeysView, - SortedItemsView, - SortedValuesView, -) +from .sorteddict import SortedDict, SortedItemsView, SortedKeysView, SortedValuesView +from .sortedlist import SortedKeyList, SortedList, SortedListWithKey __all__ = [ "SortedDict", "SortedKeysView", "SortedItemsView", "SortedValuesView", + "SortedKeyList", + "SortedList", + "SortedListWithKey", ] diff --git a/stubs/sortedcontainers/sortedlist.pyi b/stubs/sortedcontainers/sortedlist.pyi new file mode 100644
index 0000000000..8f6086b3ff --- /dev/null +++ b/stubs/sortedcontainers/sortedlist.pyi
@@ -0,0 +1,177 @@ +# stub for SortedList. This is an exact copy of +# https://github.com/grantjenks/python-sortedcontainers/blob/a419ffbd2b1c935b09f11f0971696e537fd0c510/sortedcontainers/sortedlist.pyi +# (from https://github.com/grantjenks/python-sortedcontainers/pull/107) + +from typing import ( + Any, + Callable, + Generic, + Iterable, + Iterator, + List, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + overload, +) + +_T = TypeVar("_T") +_SL = TypeVar("_SL", bound=SortedList) +_SKL = TypeVar("_SKL", bound=SortedKeyList) +_Key = Callable[[_T], Any] +_Repr = Callable[[], str] + +def recursive_repr(fillvalue: str = ...) -> Callable[[_Repr], _Repr]: ... + +class SortedList(MutableSequence[_T]): + + DEFAULT_LOAD_FACTOR: int = ... + def __init__( + self, iterable: Optional[Iterable[_T]] = ..., key: Optional[_Key[_T]] = ..., + ): ... + # NB: currently mypy does not honour return type, see mypy #3307 + @overload + def __new__(cls: Type[_SL], iterable: None, key: None) -> _SL: ... + @overload + def __new__(cls: Type[_SL], iterable: None, key: _Key[_T]) -> SortedKeyList[_T]: ... + @overload + def __new__(cls: Type[_SL], iterable: Iterable[_T], key: None) -> _SL: ... + @overload + def __new__(cls, iterable: Iterable[_T], key: _Key[_T]) -> SortedKeyList[_T]: ... + @property + def key(self) -> Optional[Callable[[_T], Any]]: ... + def _reset(self, load: int) -> None: ... + def clear(self) -> None: ... + def _clear(self) -> None: ... + def add(self, value: _T) -> None: ... + def _expand(self, pos: int) -> None: ... + def update(self, iterable: Iterable[_T]) -> None: ... + def _update(self, iterable: Iterable[_T]) -> None: ... + def discard(self, value: _T) -> None: ... + def remove(self, value: _T) -> None: ... + def _delete(self, pos: int, idx: int) -> None: ... + def _loc(self, pos: int, idx: int) -> int: ... + def _pos(self, idx: int) -> int: ... + def _build_index(self) -> None: ... + def __contains__(self, value: Any) -> bool: ... + def __delitem__(self, index: Union[int, slice]) -> None: ... + @overload + def __getitem__(self, index: int) -> _T: ... + @overload + def __getitem__(self, index: slice) -> List[_T]: ... + @overload + def _getitem(self, index: int) -> _T: ... + @overload + def _getitem(self, index: slice) -> List[_T]: ... + @overload + def __setitem__(self, index: int, value: _T) -> None: ... + @overload + def __setitem__(self, index: slice, value: Iterable[_T]) -> None: ... + def __iter__(self) -> Iterator[_T]: ... + def __reversed__(self) -> Iterator[_T]: ... + def __len__(self) -> int: ... + def reverse(self) -> None: ... + def islice( + self, start: Optional[int] = ..., stop: Optional[int] = ..., reverse=bool, + ) -> Iterator[_T]: ... + def _islice( + self, min_pos: int, min_idx: int, max_pos: int, max_idx: int, reverse: bool, + ) -> Iterator[_T]: ... + def irange( + self, + minimum: Optional[int] = ..., + maximum: Optional[int] = ..., + inclusive: Tuple[bool, bool] = ..., + reverse: bool = ..., + ) -> Iterator[_T]: ... + def bisect_left(self, value: _T) -> int: ... + def bisect_right(self, value: _T) -> int: ... + def bisect(self, value: _T) -> int: ... + def _bisect_right(self, value: _T) -> int: ... + def count(self, value: _T) -> int: ... + def copy(self: _SL) -> _SL: ... + def __copy__(self: _SL) -> _SL: ... + def append(self, value: _T) -> None: ... + def extend(self, values: Iterable[_T]) -> None: ... + def insert(self, index: int, value: _T) -> None: ... + def pop(self, index: int = ...) -> _T: ... 
+ def index( + self, value: _T, start: Optional[int] = ..., stop: Optional[int] = ... + ) -> int: ... + def __add__(self: _SL, other: Iterable[_T]) -> _SL: ... + def __radd__(self: _SL, other: Iterable[_T]) -> _SL: ... + def __iadd__(self: _SL, other: Iterable[_T]) -> _SL: ... + def __mul__(self: _SL, num: int) -> _SL: ... + def __rmul__(self: _SL, num: int) -> _SL: ... + def __imul__(self: _SL, num: int) -> _SL: ... + def __eq__(self, other: Any) -> bool: ... + def __ne__(self, other: Any) -> bool: ... + def __lt__(self, other: Sequence[_T]) -> bool: ... + def __gt__(self, other: Sequence[_T]) -> bool: ... + def __le__(self, other: Sequence[_T]) -> bool: ... + def __ge__(self, other: Sequence[_T]) -> bool: ... + def __repr__(self) -> str: ... + def _check(self) -> None: ... + +class SortedKeyList(SortedList[_T]): + def __init__( + self, iterable: Optional[Iterable[_T]] = ..., key: _Key[_T] = ... + ) -> None: ... + def __new__( + cls, iterable: Optional[Iterable[_T]] = ..., key: _Key[_T] = ... + ) -> SortedKeyList[_T]: ... + @property + def key(self) -> Callable[[_T], Any]: ... + def clear(self) -> None: ... + def _clear(self) -> None: ... + def add(self, value: _T) -> None: ... + def _expand(self, pos: int) -> None: ... + def update(self, iterable: Iterable[_T]) -> None: ... + def _update(self, iterable: Iterable[_T]) -> None: ... + # NB: Must be T to be safely passed to self.func, yet base class imposes Any + def __contains__(self, value: _T) -> bool: ... # type: ignore + def discard(self, value: _T) -> None: ... + def remove(self, value: _T) -> None: ... + def _delete(self, pos: int, idx: int) -> None: ... + def irange( + self, + minimum: Optional[int] = ..., + maximum: Optional[int] = ..., + inclusive: Tuple[bool, bool] = ..., + reverse: bool = ..., + ): ... + def irange_key( + self, + min_key: Optional[Any] = ..., + max_key: Optional[Any] = ..., + inclusive: Tuple[bool, bool] = ..., + reserve: bool = ..., + ): ... + def bisect_left(self, value: _T) -> int: ... + def bisect_right(self, value: _T) -> int: ... + def bisect(self, value: _T) -> int: ... + def bisect_key_left(self, key: Any) -> int: ... + def _bisect_key_left(self, key: Any) -> int: ... + def bisect_key_right(self, key: Any) -> int: ... + def _bisect_key_right(self, key: Any) -> int: ... + def bisect_key(self, key: Any) -> int: ... + def count(self, value: _T) -> int: ... + def copy(self: _SKL) -> _SKL: ... + def __copy__(self: _SKL) -> _SKL: ... + def index( + self, value: _T, start: Optional[int] = ..., stop: Optional[int] = ... + ) -> int: ... + def __add__(self: _SKL, other: Iterable[_T]) -> _SKL: ... + def __radd__(self: _SKL, other: Iterable[_T]) -> _SKL: ... + def __iadd__(self: _SKL, other: Iterable[_T]) -> _SKL: ... + def __mul__(self: _SKL, num: int) -> _SKL: ... + def __rmul__(self: _SKL, num: int) -> _SKL: ... + def __imul__(self: _SKL, num: int) -> _SKL: ... + def __repr__(self) -> str: ... + def _check(self) -> None: ... + +SortedListWithKey = SortedKeyList diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi
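Editor's note: a small usage sketch of the classes these stubs describe; with the stub in place, calls like the following can be type-checked by mypy:

from sortedcontainers import SortedKeyList, SortedList

nums = SortedList([3, 1, 2])
nums.add(0)                              # stays sorted: [0, 1, 2, 3]

by_length = SortedKeyList(["bbb", "a"], key=len)
by_length.add("cc")                      # ordered by key(len): ["a", "cc", "bbb"]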
index c66413f003..522244bb57 100644 --- a/stubs/txredisapi.pyi +++ b/stubs/txredisapi.pyi
@@ -16,7 +16,7 @@ """Contains *incomplete* type hints for txredisapi. """ -from typing import List, Optional, Union +from typing import List, Optional, Union, Type class RedisProtocol: def publish(self, channel: str, message: bytes): ... @@ -42,3 +42,21 @@ def lazyConnection( class SubscriberFactory: def buildProtocol(self, addr): ... + +class ConnectionHandler: ... + +class RedisFactory: + continueTrying: bool + handler: RedisProtocol + def __init__( + self, + uuid: str, + dbid: Optional[int], + poolsize: int, + isLazy: bool = False, + handler: Type = ConnectionHandler, + charset: str = "utf-8", + password: Optional[str] = None, + replyTimeout: Optional[int] = None, + convertNumbers: Optional[int] = True, + ): ... diff --git a/synapse/__init__.py b/synapse/__init__.py
index 722b53a67d..193adca624 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py
@@ -48,7 +48,7 @@ try: except ImportError: pass -__version__ = "1.21.1" +__version__ = "1.25.0" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py
index da0996edbc..dfe26dea6d 100644 --- a/synapse/_scripts/register_new_matrix_user.py +++ b/synapse/_scripts/register_new_matrix_user.py
@@ -37,7 +37,7 @@ def request_registration( exit=sys.exit, ): - url = "%s/_matrix/client/r0/admin/register" % (server_location,) + url = "%s/_synapse/admin/v1/register" % (server_location.rstrip("/"),) # Get the nonce r = requests.get(url, verify=False) diff --git a/synapse/api/auth.py b/synapse/api/auth.py
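Editor's note: the change above moves the admin registration endpoint from the old /_matrix/client/r0/admin prefix to /_synapse/admin/v1, and strips any trailing slash from the configured server location first. For example:

server_location = "https://localhost:8448/"
url = "%s/_synapse/admin/v1/register" % (server_location.rstrip("/"),)
# -> https://localhost:8448/_synapse/admin/v1/register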
index 1071a0576e..48c4d7b0be 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py
@@ -23,7 +23,7 @@ from twisted.web.server import Request import synapse.types from synapse import event_auth from synapse.api.auth_blocking import AuthBlocking -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes, HistoryVisibility, Membership from synapse.api.errors import ( AuthError, Codes, @@ -31,10 +31,12 @@ from synapse.api.errors import ( MissingClientTokenError, ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.appservice import ApplicationService from synapse.events import EventBase +from synapse.http.site import SynapseRequest from synapse.logging import opentracing as opentracing +from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import StateMap, UserID -from synapse.util.caches import register_cache from synapse.util.caches.lrucache import LruCache from synapse.util.metrics import Measure @@ -70,8 +72,9 @@ class Auth: self.store = hs.get_datastore() self.state = hs.get_state_handler() - self.token_cache = LruCache(10000) - register_cache("cache", "token_cache", self.token_cache) + self.token_cache = LruCache( + 10000, "token_cache" + ) # type: LruCache[str, Tuple[str, bool]] self._auth_blocking = AuthBlocking(self.hs) @@ -184,18 +187,12 @@ class Auth: """ try: ip_addr = self.hs.get_ip_from_request(request) - user_agent = request.requestHeaders.getRawHeaders( - b"User-Agent", default=[b""] - )[0].decode("ascii", "surrogateescape") + user_agent = request.get_user_agent("") access_token = self.get_access_token_from_request(request) user_id, app_service = await self._get_appservice_user_id(request) if user_id: - request.authenticated_entity = user_id - opentracing.set_tag("authenticated_entity", user_id) - opentracing.set_tag("appservice_id", app_service.id) - if ip_addr and self._track_appservice_user_ips: await self.store.insert_client_ip( user_id=user_id, @@ -205,31 +202,38 @@ class Auth: device_id="dummy-device", # stubbed ) - return synapse.types.create_requester(user_id, app_service=app_service) + requester = synapse.types.create_requester( + user_id, app_service=app_service + ) + + request.requester = user_id + opentracing.set_tag("authenticated_entity", user_id) + opentracing.set_tag("user_id", user_id) + opentracing.set_tag("appservice_id", app_service.id) + + return requester user_info = await self.get_user_by_access_token( access_token, rights, allow_expired=allow_expired ) - user = user_info["user"] - token_id = user_info["token_id"] - is_guest = user_info["is_guest"] - shadow_banned = user_info["shadow_banned"] + token_id = user_info.token_id + is_guest = user_info.is_guest + shadow_banned = user_info.shadow_banned # Deny the request if the user account has expired. if self._account_validity.enabled and not allow_expired: - user_id = user.to_string() - if await self.store.is_account_expired(user_id, self.clock.time_msec()): + if await self.store.is_account_expired( + user_info.user_id, self.clock.time_msec() + ): raise AuthError( 403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT ) - # device_id may not be present if get_user_by_access_token has been - # stubbed out. 
- device_id = user_info.get("device_id") + device_id = user_info.device_id - if user and access_token and ip_addr: + if access_token and ip_addr: await self.store.insert_client_ip( - user_id=user.to_string(), + user_id=user_info.token_owner, access_token=access_token, ip=ip_addr, user_agent=user_agent, @@ -243,19 +247,23 @@ class Auth: errcode=Codes.GUEST_ACCESS_FORBIDDEN, ) - request.authenticated_entity = user.to_string() - opentracing.set_tag("authenticated_entity", user.to_string()) - if device_id: - opentracing.set_tag("device_id", device_id) - - return synapse.types.create_requester( - user, + requester = synapse.types.create_requester( + user_info.user_id, token_id, is_guest, shadow_banned, device_id, app_service=app_service, + authenticated_entity=user_info.token_owner, ) + + request.requester = requester + opentracing.set_tag("authenticated_entity", user_info.token_owner) + opentracing.set_tag("user_id", user_info.user_id) + if device_id: + opentracing.set_tag("device_id", device_id) + + return requester except KeyError: raise MissingClientTokenError() @@ -286,7 +294,7 @@ class Auth: async def get_user_by_access_token( self, token: str, rights: str = "access", allow_expired: bool = False, - ) -> dict: + ) -> TokenLookupResult: """ Validate access token and get user_id from it Args: @@ -295,13 +303,7 @@ class Auth: allow this allow_expired: If False, raises an InvalidClientTokenError if the token is expired - Returns: - dict that includes: - `user` (UserID) - `is_guest` (bool) - `shadow_banned` (bool) - `token_id` (int|None): access token id. May be None if guest - `device_id` (str|None): device corresponding to access token + Raises: InvalidClientTokenError if a user by that token exists, but the token is expired @@ -311,9 +313,9 @@ class Auth: if rights == "access": # first look in the database - r = await self._look_up_user_by_access_token(token) + r = await self.store.get_user_by_access_token(token) if r: - valid_until_ms = r["valid_until_ms"] + valid_until_ms = r.valid_until_ms if ( not allow_expired and valid_until_ms is not None @@ -330,7 +332,6 @@ class Auth: # otherwise it needs to be a valid macaroon try: user_id, guest = self._parse_and_validate_macaroon(token, rights) - user = UserID.from_string(user_id) if rights == "access": if not guest: @@ -356,23 +357,17 @@ class Auth: raise InvalidClientTokenError( "Guest access token used for regular user" ) - ret = { - "user": user, - "is_guest": True, - "shadow_banned": False, - "token_id": None, + + ret = TokenLookupResult( + user_id=user_id, + is_guest=True, # all guests get the same device id - "device_id": GUEST_DEVICE_ID, - } + device_id=GUEST_DEVICE_ID, + ) elif rights == "delete_pusher": # We don't store these tokens in the database - ret = { - "user": user, - "is_guest": False, - "shadow_banned": False, - "token_id": None, - "device_id": None, - } + + ret = TokenLookupResult(user_id=user_id, is_guest=False) else: raise RuntimeError("Unknown rights setting %s", rights) return ret @@ -481,31 +476,15 @@ class Auth: now = self.hs.get_clock().time_msec() return now < expiry - async def _look_up_user_by_access_token(self, token): - ret = await self.store.get_user_by_access_token(token) - if not ret: - return None - - # we use ret.get() below because *lots* of unit tests stub out - # get_user_by_access_token in a way where it only returns a couple of - # the fields. 
- user_info = { - "user": UserID.from_string(ret.get("name")), - "token_id": ret.get("token_id", None), - "is_guest": False, - "shadow_banned": ret.get("shadow_banned"), - "device_id": ret.get("device_id"), - "valid_until_ms": ret.get("valid_until_ms"), - } - return user_info - - def get_appservice_by_req(self, request): + def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService: token = self.get_access_token_from_request(request) service = self.store.get_app_service_by_token(token) if not service: logger.warning("Unrecognised appservice access token.") raise InvalidClientTokenError() - request.authenticated_entity = service.sender + request.requester = synapse.types.create_requester( + service.sender, app_service=service + ) return service async def is_server_admin(self, user: UserID) -> bool: @@ -669,7 +648,8 @@ class Auth: ) if ( visibility - and visibility.content["history_visibility"] == "world_readable" + and visibility.content.get("history_visibility") + == HistoryVisibility.WORLD_READABLE ): return Membership.JOIN, None raise AuthError( diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py
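Editor's note: the auth changes above replace the loosely-typed dict returned by token lookups with a TokenLookupResult object, so callers read attributes instead of dictionary keys. Roughly:

# Before (dict-based lookup result):
#   user_id   = user_info["user"].to_string()
#   device_id = user_info.get("device_id")
# After (TokenLookupResult attributes):
#   user_id   = user_info.user_id
#   device_id = user_info.device_id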
index d8fafd7cb8..d8088f524a 100644 --- a/synapse/api/auth_blocking.py +++ b/synapse/api/auth_blocking.py
@@ -14,10 +14,12 @@ # limitations under the License. import logging +from typing import Optional from synapse.api.constants import LimitBlockingTypes, UserTypes from synapse.api.errors import Codes, ResourceLimitError from synapse.config.server import is_threepid_reserved +from synapse.types import Requester logger = logging.getLogger(__name__) @@ -33,24 +35,54 @@ class AuthBlocking: self._max_mau_value = hs.config.max_mau_value self._limit_usage_by_mau = hs.config.limit_usage_by_mau self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids + self._server_name = hs.hostname + self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips - async def check_auth_blocking(self, user_id=None, threepid=None, user_type=None): + async def check_auth_blocking( + self, + user_id: Optional[str] = None, + threepid: Optional[dict] = None, + user_type: Optional[str] = None, + requester: Optional[Requester] = None, + ): """Checks if the user should be rejected for some external reason, such as monthly active user limiting or global disable flag Args: - user_id(str|None): If present, checks for presence against existing + user_id: If present, checks for presence against existing MAU cohort - threepid(dict|None): If present, checks for presence against configured + threepid: If present, checks for presence against configured reserved threepid. Used in cases where the user is trying register with a MAU blocked server, normally they would be rejected but their threepid is on the reserved list. user_id and threepid should never be set at the same time. - user_type(str|None): If present, is used to decide whether to check against + user_type: If present, is used to decide whether to check against certain blocking reasons like MAU. + + requester: If present, and the authenticated entity is a user, checks for + presence against existing MAU cohort. Passing in both a `user_id` and + `requester` is an error. """ + if requester and user_id: + raise Exception( + "Passed in both 'user_id' and 'requester' to 'check_auth_blocking'" + ) + + if requester: + if requester.authenticated_entity.startswith("@"): + user_id = requester.authenticated_entity + elif requester.authenticated_entity == self._server_name: + # We never block the server from doing actions on behalf of + # users. + return + elif requester.app_service and not self._track_appservice_user_ips: + # If we're authenticated as an appservice then we only block + # auth if `track_appservice_user_ips` is set, as that option + # implicitly means that application services are part of MAU + # limits. + return # Never fail an auth check for the server notices users or support user # This can be a problem where event creation is prohibited due to blocking diff --git a/synapse/api/constants.py b/synapse/api/constants.py
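Editor's note: check_auth_blocking now also accepts a Requester, skipping MAU checks for the server itself and for appservice senders unless track_appservice_user_ips is set. A hedged distillation of that branching (illustrative helper, not Synapse code):

def should_apply_mau_check(authenticated_entity, server_name, is_appservice, track_appservice_user_ips):
    if authenticated_entity.startswith("@"):
        return True      # a real user: check against the MAU cohort
    if authenticated_entity == server_name:
        return False     # the server acting on a user's behalf is never blocked
    if is_appservice and not track_appservice_user_ips:
        return False     # appservices only count towards MAU when their users are tracked
    return True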
index 46013cde15..565a8cd76a 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py
@@ -95,6 +95,8 @@ class EventTypes: Presence = "m.presence" + Dummy = "org.matrix.dummy_event" + class RejectedReason: AUTH_ERROR = "auth_error" @@ -155,3 +157,15 @@ class EventContentFields: class RoomEncryptionAlgorithms: MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2" DEFAULT = MEGOLM_V1_AES_SHA2 + + +class AccountDataTypes: + DIRECT = "m.direct" + IGNORED_USER_LIST = "m.ignored_user_list" + + +class HistoryVisibility: + INVITED = "invited" + JOINED = "joined" + SHARED = "shared" + WORLD_READABLE = "world_readable" diff --git a/synapse/app/_base.py b/synapse/app/_base.py
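Editor's note: a small usage sketch of the new HistoryVisibility constants, mirroring the auth.py change above that compares a room's history_visibility safely via .get():

from synapse.api.constants import HistoryVisibility

content = {"history_visibility": "world_readable"}
if content.get("history_visibility") == HistoryVisibility.WORLD_READABLE:
    print("room is world readable")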
index fb476ddaf5..37ecdbe3d8 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py
@@ -28,9 +28,11 @@ from twisted.protocols.tls import TLSMemoryBIOFactory import synapse from synapse.app import check_bind_error +from synapse.app.phone_stats_home import start_phone_stats_home from synapse.config.server import ListenerConfig from synapse.crypto import context_factory from synapse.logging.context import PreserveLoggingContext +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.util.async_helpers import Linearizer from synapse.util.daemonize import daemonize_process from synapse.util.rlimit import change_resource_limit @@ -48,7 +50,6 @@ def register_sighup(func, *args, **kwargs): Args: func (function): Function to be called when sent a SIGHUP signal. - Will be called with a single default argument, the homeserver. *args, **kwargs: args and kwargs to be passed to the target function. """ _sighup_callbacks.append((func, args, kwargs)) @@ -244,19 +245,30 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]): # Set up the SIGHUP machinery. if hasattr(signal, "SIGHUP"): + reactor = hs.get_reactor() + + @wrap_as_background_process("sighup") def handle_sighup(*args, **kwargs): # Tell systemd our state, if we're using it. This will silently fail if # we're not using systemd. sdnotify(b"RELOADING=1") for i, args, kwargs in _sighup_callbacks: - i(hs, *args, **kwargs) + i(*args, **kwargs) sdnotify(b"READY=1") - signal.signal(signal.SIGHUP, handle_sighup) + # We defer running the sighup handlers until next reactor tick. This + # is so that we're in a sane state, e.g. flushing the logs may fail + # if the sighup happens in the middle of writing a log entry. + def run_sighup(*args, **kwargs): + # `callFromThread` should be "signal safe" as well as thread + # safe. + reactor.callFromThread(handle_sighup, *args, **kwargs) + + signal.signal(signal.SIGHUP, run_sighup) - register_sighup(refresh_certificate) + register_sighup(refresh_certificate, hs) # Load the certificate from disk. refresh_certificate(hs) @@ -271,9 +283,19 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]): hs.get_datastore().db_pool.start_profiling() hs.get_pusherpool().start() + # Log when we start the shut down process. + hs.get_reactor().addSystemEventTrigger( + "before", "shutdown", logger.info, "Shutting down..." + ) + setup_sentry(hs) setup_sdnotify(hs) + # If background tasks are running on the main process, start collecting the + # phone home stats. + if hs.config.run_background_tasks: + start_phone_stats_home(hs) + # We now freeze all allocated objects in the hopes that (almost) # everything currently allocated are things that will be used for the # rest of time. Doing so means less work each GC (hopefully). diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
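Editor's note: the _base.py change above stops running SIGHUP callbacks inside the signal handler itself and instead bounces them onto the reactor via callFromThread, which is safe to call from a signal handler. A stripped-down sketch of the pattern (placeholder callback, not Synapse code):

import signal

from twisted.internet import reactor

def reload_everything():
    print("re-reading certificates and log config")   # placeholder work

def run_sighup(*args, **kwargs):
    # Defer the real work to the next reactor tick; callFromThread is
    # "signal safe" as well as thread safe.
    reactor.callFromThread(reload_everything)

signal.signal(signal.SIGHUP, run_sighup)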
index 7d309b1bb0..b4bd4d8e7a 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py
@@ -89,7 +89,7 @@ async def export_data_command(hs, args): user_id = args.user_id directory = args.output_directory - res = await hs.get_handlers().admin_handler.export_user_data( + res = await hs.get_admin_handler().export_user_data( user_id, FileExfiltrationWriter(user_id, directory=directory) ) print(res) @@ -208,6 +208,7 @@ def start(config_options): # Explicitly disable background processes config.update_user_directory = False + config.run_background_tasks = False config.start_pushers = False config.send_federation = False diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index c38413c893..fa23d9bb20 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py
@@ -89,7 +89,7 @@ from synapse.replication.tcp.streams import ( ToDeviceStream, ) from synapse.rest.admin import register_servlets_for_media_repo -from synapse.rest.client.v1 import events +from synapse.rest.client.v1 import events, room from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet from synapse.rest.client.v1.login import LoginRestServlet from synapse.rest.client.v1.profile import ( @@ -98,20 +98,6 @@ from synapse.rest.client.v1.profile import ( ProfileRestServlet, ) from synapse.rest.client.v1.push_rule import PushRuleRestServlet -from synapse.rest.client.v1.room import ( - JoinedRoomMemberListRestServlet, - JoinRoomAliasServlet, - PublicRoomListRestServlet, - RoomEventContextServlet, - RoomInitialSyncRestServlet, - RoomMemberListRestServlet, - RoomMembershipRestServlet, - RoomMessageListRestServlet, - RoomSendEventRestServlet, - RoomStateEventRestServlet, - RoomStateRestServlet, - RoomTypingRestServlet, -) from synapse.rest.client.v1.voip import VoipRestServlet from synapse.rest.client.v2_alpha import groups, sync, user_directory from synapse.rest.client.v2_alpha._base import client_patterns @@ -127,12 +113,16 @@ from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource from synapse.server import HomeServer, cache_in_self from synapse.storage.databases.main.censor_events import CensorEventsStore +from synapse.storage.databases.main.client_ips import ClientIpWorkerStore from synapse.storage.databases.main.media_repository import MediaRepositoryStore +from synapse.storage.databases.main.metrics import ServerMetricsStore from synapse.storage.databases.main.monthly_active_users import ( MonthlyActiveUsersWorkerStore, ) from synapse.storage.databases.main.presence import UserPresenceState from synapse.storage.databases.main.search import SearchWorkerStore +from synapse.storage.databases.main.stats import StatsStore +from synapse.storage.databases.main.transactions import TransactionWorkerStore from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore from synapse.storage.databases.main.user_directory import UserDirectoryStore from synapse.types import ReadReceipt @@ -262,7 +252,6 @@ class GenericWorkerPresence(BasePresenceHandler): super().__init__(hs) self.hs = hs self.is_mine_id = hs.is_mine_id - self.http_client = hs.get_simple_http_client() self._presence_enabled = hs.config.use_presence @@ -454,6 +443,7 @@ class GenericWorkerSlavedStore( # FIXME(#3714): We need to add UserDirectoryStore as we write directly # rather than going via the correct worker. 
UserDirectoryStore, + StatsStore, UIAuthWorkerStore, SlavedDeviceInboxStore, SlavedDeviceStore, @@ -463,6 +453,7 @@ class GenericWorkerSlavedStore( SlavedAccountDataStore, SlavedPusherStore, CensorEventsStore, + ClientIpWorkerStore, SlavedEventStore, SlavedKeyStore, RoomStore, @@ -476,7 +467,9 @@ class GenericWorkerSlavedStore( SlavedFilteringStore, MonthlyActiveUsersWorkerStore, MediaRepositoryStore, + ServerMetricsStore, SearchWorkerStore, + TransactionWorkerStore, BaseSlavedStore, ): pass @@ -505,12 +498,6 @@ class GenericWorkerServer(HomeServer): elif name == "client": resource = JsonResource(self, canonical_json=False) - PublicRoomListRestServlet(self).register(resource) - RoomMemberListRestServlet(self).register(resource) - JoinedRoomMemberListRestServlet(self).register(resource) - RoomStateRestServlet(self).register(resource) - RoomEventContextServlet(self).register(resource) - RoomMessageListRestServlet(self).register(resource) RegisterRestServlet(self).register(resource) LoginRestServlet(self).register(resource) ThreepidRestServlet(self).register(resource) @@ -519,22 +506,19 @@ class GenericWorkerServer(HomeServer): VoipRestServlet(self).register(resource) PushRuleRestServlet(self).register(resource) VersionsRestServlet(self).register(resource) - RoomSendEventRestServlet(self).register(resource) - RoomMembershipRestServlet(self).register(resource) - RoomStateEventRestServlet(self).register(resource) - JoinRoomAliasServlet(self).register(resource) + ProfileAvatarURLRestServlet(self).register(resource) ProfileDisplaynameRestServlet(self).register(resource) ProfileRestServlet(self).register(resource) KeyUploadServlet(self).register(resource) AccountDataServlet(self).register(resource) RoomAccountDataServlet(self).register(resource) - RoomTypingRestServlet(self).register(resource) sync.register_servlets(self, resource) events.register_servlets(self, resource) + room.register_servlets(self, resource, True) + room.register_deprecated_servlets(self, resource) InitialSyncRestServlet(self).register(resource) - RoomInitialSyncRestServlet(self).register(resource) user_directory.register_servlets(self, resource) @@ -782,10 +766,6 @@ class FederationSenderHandler: send_queue.process_rows_for_federation(self.federation_sender, rows) await self.update_token(token) - # We also need to poke the federation sender when new events happen - elif stream_name == "events": - self.federation_sender.notify_new_events(token) - # ... and when new receipts happen elif stream_name == ReceiptsStream.NAME: await self._on_new_receipts(rows) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index dff739e106..8d9b53be53 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py
@@ -17,13 +17,9 @@ import gc import logging -import math import os -import resource import sys -from typing import Iterable - -from prometheus_client import Gauge +from typing import Iterable, Iterator from twisted.application import service from twisted.internet import defer, reactor @@ -60,8 +56,6 @@ from synapse.http.server import ( from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.module_api import ModuleApi from synapse.python_dependencies import check_requirements from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory @@ -69,6 +63,7 @@ from synapse.rest import ClientRestResource from synapse.rest.admin import AdminRestResource from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.rest.synapse.client.pick_username import pick_username_resource from synapse.rest.well_known import WellKnownResource from synapse.server import HomeServer from synapse.storage import DataStore @@ -96,7 +91,7 @@ class SynapseHomeServer(HomeServer): tls = listener_config.tls site_tag = listener_config.http_options.tag if site_tag is None: - site_tag = port + site_tag = str(port) # We always include a health resource. resources = {"/health": HealthResource()} @@ -111,9 +106,12 @@ class SynapseHomeServer(HomeServer): additional_resources = listener_config.http_options.additional_resources logger.debug("Configuring additional resources: %r", additional_resources) - module_api = ModuleApi(self, self.get_auth_handler()) + module_api = self.get_module_api() for path, resmodule in additional_resources.items(): - handler_cls, config = load_module(resmodule) + handler_cls, config = load_module( + resmodule, + ("listeners", site_tag, "additional_resources", "<%s>" % (path,)), + ) handler = handler_cls(config, module_api) if IResource.providedBy(handler): resource = handler @@ -195,6 +193,7 @@ class SynapseHomeServer(HomeServer): "/_matrix/client/versions": client_resource, "/.well-known/matrix/client": WellKnownResource(self), "/_synapse/admin": AdminRestResource(self), + "/_synapse/client/pick_username": pick_username_resource(self), } ) @@ -334,20 +333,6 @@ class SynapseHomeServer(HomeServer): logger.warning("Unrecognized listener type: %s", listener.type) -# Gauges to expose monthly active user control metrics -current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU") -current_mau_by_service_gauge = Gauge( - "synapse_admin_mau_current_mau_by_service", - "Current MAU by service", - ["app_service"], -) -max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit") -registered_reserved_users_mau_gauge = Gauge( - "synapse_admin_mau:registered_reserved_users", - "Registered users with reserved threepids", -) - - def setup(config_options): """ Args: @@ -362,7 +347,10 @@ def setup(config_options): "Synapse Homeserver", config_options ) except ConfigError as e: - sys.stderr.write("\nERROR: %s\n" % (e,)) + sys.stderr.write("\n") + for f in format_config_error(e): + sys.stderr.write(f) + sys.stderr.write("\n") sys.exit(1) if not config: @@ -389,8 +377,6 @@ def setup(config_options): except UpgradeDatabaseException as e: quit_with_error("Failed to upgrade database: %s" % (e,)) - hs.setup_master() - async def do_acme() -> bool: """ Reprovision an ACME 
certificate, if it's required. @@ -467,6 +453,38 @@ def setup(config_options): return hs +def format_config_error(e: ConfigError) -> Iterator[str]: + """ + Formats a config error neatly + + The idea is to format the immediate error, plus the "causes" of those errors, + hopefully in a way that makes sense to the user. For example: + + Error in configuration at 'oidc_config.user_mapping_provider.config.display_name_template': + Failed to parse config for module 'JinjaOidcMappingProvider': + invalid jinja template: + unexpected end of template, expected 'end of print statement'. + + Args: + e: the error to be formatted + + Returns: An iterator which yields string fragments to be formatted + """ + yield "Error in configuration" + + if e.path: + yield " at '%s'" % (".".join(e.path),) + + yield ":\n %s" % (e.msg,) + + e = e.__cause__ + indent = 1 + while e: + indent += 1 + yield ":\n%s%s" % (" " * indent, str(e)) + e = e.__cause__ + + class SynapseService(service.Service): """ A twisted Service class that will start synapse. Used to run synapse @@ -486,92 +504,6 @@ class SynapseService(service.Service): return self._port.stopListening() -# Contains the list of processes we will be monitoring -# currently either 0 or 1 -_stats_process = [] - - -async def phone_stats_home(hs, stats, stats_process=_stats_process): - logger.info("Gathering stats for reporting") - now = int(hs.get_clock().time()) - uptime = int(now - hs.start_time) - if uptime < 0: - uptime = 0 - - # - # Performance statistics. Keep this early in the function to maintain reliability of `test_performance_100` test. - # - old = stats_process[0] - new = (now, resource.getrusage(resource.RUSAGE_SELF)) - stats_process[0] = new - - # Get RSS in bytes - stats["memory_rss"] = new[1].ru_maxrss - - # Get CPU time in % of a single core, not % of all cores - used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - ( - old[1].ru_utime + old[1].ru_stime - ) - if used_cpu_time == 0 or new[0] == old[0]: - stats["cpu_average"] = 0 - else: - stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100) - - # - # General statistics - # - - stats["homeserver"] = hs.config.server_name - stats["server_context"] = hs.config.server_context - stats["timestamp"] = now - stats["uptime_seconds"] = uptime - version = sys.version_info - stats["python_version"] = "{}.{}.{}".format( - version.major, version.minor, version.micro - ) - stats["total_users"] = await hs.get_datastore().count_all_users() - - total_nonbridged_users = await hs.get_datastore().count_nonbridged_users() - stats["total_nonbridged_users"] = total_nonbridged_users - - daily_user_type_results = await hs.get_datastore().count_daily_user_type() - for name, count in daily_user_type_results.items(): - stats["daily_user_type_" + name] = count - - room_count = await hs.get_datastore().get_room_count() - stats["total_room_count"] = room_count - - stats["daily_active_users"] = await hs.get_datastore().count_daily_users() - stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users() - stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms() - stats["daily_messages"] = await hs.get_datastore().count_daily_messages() - - r30_results = await hs.get_datastore().count_r30_users() - for name, count in r30_results.items(): - stats["r30_users_" + name] = count - - daily_sent_messages = await hs.get_datastore().count_daily_sent_messages() - stats["daily_sent_messages"] = daily_sent_messages - stats["cache_factor"] = hs.config.caches.global_factor - 
stats["event_cache_size"] = hs.config.caches.event_cache_size - - # - # Database version - # - - # This only reports info about the *main* database. - stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__ - stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version - - logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) - try: - await hs.get_proxied_http_client().put_json( - hs.config.report_stats_endpoint, stats - ) - except Exception as e: - logger.warning("Error reporting stats: %s", e) - - def run(hs): PROFILE_SYNAPSE = False if PROFILE_SYNAPSE: @@ -597,81 +529,6 @@ def run(hs): ThreadPool._worker = profile(ThreadPool._worker) reactor.run = profile(reactor.run) - clock = hs.get_clock() - - stats = {} - - def performance_stats_init(): - _stats_process.clear() - _stats_process.append( - (int(hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF)) - ) - - def start_phone_stats_home(): - return run_as_background_process( - "phone_stats_home", phone_stats_home, hs, stats - ) - - def generate_user_daily_visit_stats(): - return run_as_background_process( - "generate_user_daily_visits", hs.get_datastore().generate_user_daily_visits - ) - - # Rather than update on per session basis, batch up the requests. - # If you increase the loop period, the accuracy of user_daily_visits - # table will decrease - clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000) - - # monthly active user limiting functionality - def reap_monthly_active_users(): - return run_as_background_process( - "reap_monthly_active_users", hs.get_datastore().reap_monthly_active_users - ) - - clock.looping_call(reap_monthly_active_users, 1000 * 60 * 60) - reap_monthly_active_users() - - async def generate_monthly_active_users(): - current_mau_count = 0 - current_mau_count_by_service = {} - reserved_users = () - store = hs.get_datastore() - if hs.config.limit_usage_by_mau or hs.config.mau_stats_only: - current_mau_count = await store.get_monthly_active_count() - current_mau_count_by_service = ( - await store.get_monthly_active_count_by_service() - ) - reserved_users = await store.get_registered_reserved_users() - current_mau_gauge.set(float(current_mau_count)) - - for app_service, count in current_mau_count_by_service.items(): - current_mau_by_service_gauge.labels(app_service).set(float(count)) - - registered_reserved_users_mau_gauge.set(float(len(reserved_users))) - max_mau_gauge.set(float(hs.config.max_mau_value)) - - def start_generate_monthly_active_users(): - return run_as_background_process( - "generate_monthly_active_users", generate_monthly_active_users - ) - - start_generate_monthly_active_users() - if hs.config.limit_usage_by_mau or hs.config.mau_stats_only: - clock.looping_call(start_generate_monthly_active_users, 5 * 60 * 1000) - # End of monthly active user settings - - if hs.config.report_stats: - logger.info("Scheduling stats reporting for 3 hour intervals") - clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000) - - # We need to defer this init for the cases that we daemonize - # otherwise the process ID we get is that of the non-daemon process - clock.call_later(0, performance_stats_init) - - # We wait 5 minutes to send the first set of stats as the server can - # be quite busy the first few minutes - clock.call_later(5 * 60, start_phone_stats_home) - _base.start_reactor( "synapse-homeserver", soft_file_limit=hs.config.soft_file_limit, diff --git a/synapse/app/phone_stats_home.py 
b/synapse/app/phone_stats_home.py new file mode 100644
index 0000000000..c38cf8231f --- /dev/null +++ b/synapse/app/phone_stats_home.py
@@ -0,0 +1,190 @@ +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import math +import resource +import sys + +from prometheus_client import Gauge + +from synapse.metrics.background_process_metrics import wrap_as_background_process + +logger = logging.getLogger("synapse.app.homeserver") + +# Contains the list of processes we will be monitoring +# currently either 0 or 1 +_stats_process = [] + +# Gauges to expose monthly active user control metrics +current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU") +current_mau_by_service_gauge = Gauge( + "synapse_admin_mau_current_mau_by_service", + "Current MAU by service", + ["app_service"], +) +max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit") +registered_reserved_users_mau_gauge = Gauge( + "synapse_admin_mau:registered_reserved_users", + "Registered users with reserved threepids", +) + + +@wrap_as_background_process("phone_stats_home") +async def phone_stats_home(hs, stats, stats_process=_stats_process): + logger.info("Gathering stats for reporting") + now = int(hs.get_clock().time()) + uptime = int(now - hs.start_time) + if uptime < 0: + uptime = 0 + + # + # Performance statistics. Keep this early in the function to maintain reliability of `test_performance_100` test. 
+ # + old = stats_process[0] + new = (now, resource.getrusage(resource.RUSAGE_SELF)) + stats_process[0] = new + + # Get RSS in bytes + stats["memory_rss"] = new[1].ru_maxrss + + # Get CPU time in % of a single core, not % of all cores + used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - ( + old[1].ru_utime + old[1].ru_stime + ) + if used_cpu_time == 0 or new[0] == old[0]: + stats["cpu_average"] = 0 + else: + stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100) + + # + # General statistics + # + + stats["homeserver"] = hs.config.server_name + stats["server_context"] = hs.config.server_context + stats["timestamp"] = now + stats["uptime_seconds"] = uptime + version = sys.version_info + stats["python_version"] = "{}.{}.{}".format( + version.major, version.minor, version.micro + ) + stats["total_users"] = await hs.get_datastore().count_all_users() + + total_nonbridged_users = await hs.get_datastore().count_nonbridged_users() + stats["total_nonbridged_users"] = total_nonbridged_users + + daily_user_type_results = await hs.get_datastore().count_daily_user_type() + for name, count in daily_user_type_results.items(): + stats["daily_user_type_" + name] = count + + room_count = await hs.get_datastore().get_room_count() + stats["total_room_count"] = room_count + + stats["daily_active_users"] = await hs.get_datastore().count_daily_users() + stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users() + stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms() + stats["daily_messages"] = await hs.get_datastore().count_daily_messages() + + r30_results = await hs.get_datastore().count_r30_users() + for name, count in r30_results.items(): + stats["r30_users_" + name] = count + + daily_sent_messages = await hs.get_datastore().count_daily_sent_messages() + stats["daily_sent_messages"] = daily_sent_messages + stats["cache_factor"] = hs.config.caches.global_factor + stats["event_cache_size"] = hs.config.caches.event_cache_size + + # + # Database version + # + + # This only reports info about the *main* database. + stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__ + stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version + + # + # Logging configuration + # + synapse_logger = logging.getLogger("synapse") + log_level = synapse_logger.getEffectiveLevel() + stats["log_level"] = logging.getLevelName(log_level) + + logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) + try: + await hs.get_proxied_http_client().put_json( + hs.config.report_stats_endpoint, stats + ) + except Exception as e: + logger.warning("Error reporting stats: %s", e) + + +def start_phone_stats_home(hs): + """ + Start the background tasks which report phone home stats. + """ + clock = hs.get_clock() + + stats = {} + + def performance_stats_init(): + _stats_process.clear() + _stats_process.append( + (int(hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF)) + ) + + # Rather than update on per session basis, batch up the requests. 
+ # If you increase the loop period, the accuracy of user_daily_visits + # table will decrease + clock.looping_call(hs.get_datastore().generate_user_daily_visits, 5 * 60 * 1000) + + # monthly active user limiting functionality + clock.looping_call(hs.get_datastore().reap_monthly_active_users, 1000 * 60 * 60) + hs.get_datastore().reap_monthly_active_users() + + @wrap_as_background_process("generate_monthly_active_users") + async def generate_monthly_active_users(): + current_mau_count = 0 + current_mau_count_by_service = {} + reserved_users = () + store = hs.get_datastore() + if hs.config.limit_usage_by_mau or hs.config.mau_stats_only: + current_mau_count = await store.get_monthly_active_count() + current_mau_count_by_service = ( + await store.get_monthly_active_count_by_service() + ) + reserved_users = await store.get_registered_reserved_users() + current_mau_gauge.set(float(current_mau_count)) + + for app_service, count in current_mau_count_by_service.items(): + current_mau_by_service_gauge.labels(app_service).set(float(count)) + + registered_reserved_users_mau_gauge.set(float(len(reserved_users))) + max_mau_gauge.set(float(hs.config.max_mau_value)) + + if hs.config.limit_usage_by_mau or hs.config.mau_stats_only: + generate_monthly_active_users() + clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000) + # End of monthly active user settings + + if hs.config.report_stats: + logger.info("Scheduling stats reporting for 3 hour intervals") + clock.looping_call(phone_stats_home, 3 * 60 * 60 * 1000, hs, stats) + + # We need to defer this init for the cases that we daemonize + # otherwise the process ID we get is that of the non-daemon process + clock.call_later(0, performance_stats_init) + + # We wait 5 minutes to send the first set of stats as the server can + # be quite busy the first few minutes + clock.call_later(5 * 60, phone_stats_home, hs, stats) diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
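The CPU and memory figures reported by phone_stats_home above are derived from two resource.getrusage samples. A minimal standalone sketch of that calculation, assuming only the Python standard library (the sample() helper and the busy-loop are illustrative, not part of Synapse):

import math
import resource
import time

def sample():
    # Pair a wall-clock timestamp with the process's own resource usage.
    return (time.time(), resource.getrusage(resource.RUSAGE_SELF))

old = sample()
sum(i * i for i in range(1_000_000))  # stand-in for real work
new = sample()

# CPU seconds consumed between the two samples (user + system time).
used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - (
    old[1].ru_utime + old[1].ru_stime
)

if used_cpu_time == 0 or new[0] == old[0]:
    cpu_average = 0
else:
    # Expressed as a percentage of a single core, as in the hunk above.
    cpu_average = math.floor(used_cpu_time / (new[0] - old[0]) * 100)

rss = new[1].ru_maxrss  # peak RSS; kilobytes on Linux, bytes on macOS
print(cpu_average, rss)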
index 13ec1f71a6..3944780a42 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py
@@ -14,14 +14,15 @@ # limitations under the License. import logging import re -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Iterable, List, Match, Optional from synapse.api.constants import EventTypes -from synapse.appservice.api import ApplicationServiceApi -from synapse.types import GroupID, get_domain_from_id -from synapse.util.caches.descriptors import cached +from synapse.events import EventBase +from synapse.types import GroupID, JsonDict, UserID, get_domain_from_id +from synapse.util.caches.descriptors import _CacheContext, cached if TYPE_CHECKING: + from synapse.appservice.api import ApplicationServiceApi from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) @@ -32,38 +33,6 @@ class ApplicationServiceState: UP = "up" -class AppServiceTransaction: - """Represents an application service transaction.""" - - def __init__(self, service, id, events): - self.service = service - self.id = id - self.events = events - - async def send(self, as_api: ApplicationServiceApi) -> bool: - """Sends this transaction using the provided AS API interface. - - Args: - as_api: The API to use to send. - Returns: - True if the transaction was sent. - """ - return await as_api.push_bulk( - service=self.service, events=self.events, txn_id=self.id - ) - - async def complete(self, store: "DataStore") -> None: - """Completes this transaction as successful. - - Marks this transaction ID on the application service and removes the - transaction contents from the database. - - Args: - store: The database store to operate on. - """ - await store.complete_appservice_txn(service=self.service, txn_id=self.id) - - class ApplicationService: """Defines an application service. This definition is mostly what is provided to the /register AS API. 
@@ -83,14 +52,15 @@ class ApplicationService: self, token, hostname, + id, + sender, url=None, namespaces=None, hs_token=None, - sender=None, - id=None, protocols=None, rate_limited=True, ip_range_whitelist=None, + supports_ephemeral=False, ): self.token = token self.url = ( @@ -102,6 +72,7 @@ class ApplicationService: self.namespaces = self._check_namespaces(namespaces) self.id = id self.ip_range_whitelist = ip_range_whitelist + self.supports_ephemeral = supports_ephemeral if "|" in self.id: raise Exception("application service ID cannot contain '|' character") @@ -161,19 +132,21 @@ class ApplicationService: raise ValueError("Expected string for 'regex' in ns '%s'" % ns) return namespaces - def _matches_regex(self, test_string, namespace_key): + def _matches_regex(self, test_string: str, namespace_key: str) -> Optional[Match]: for regex_obj in self.namespaces[namespace_key]: if regex_obj["regex"].match(test_string): return regex_obj return None - def _is_exclusive(self, ns_key, test_string): + def _is_exclusive(self, ns_key: str, test_string: str) -> bool: regex_obj = self._matches_regex(test_string, ns_key) if regex_obj: return regex_obj["exclusive"] return False - async def _matches_user(self, event, store): + async def _matches_user( + self, event: Optional[EventBase], store: Optional["DataStore"] = None + ) -> bool: if not event: return False @@ -188,11 +161,22 @@ class ApplicationService: if not store: return False - does_match = await self._matches_user_in_member_list(event.room_id, store) + does_match = await self.matches_user_in_member_list(event.room_id, store) return does_match @cached(num_args=1, cache_context=True) - async def _matches_user_in_member_list(self, room_id, store, cache_context): + async def matches_user_in_member_list( + self, room_id: str, store: "DataStore", cache_context: _CacheContext, + ) -> bool: + """Check if this service is interested in a room based upon its membership + + Args: + room_id: The room to check. + store: The datastore to query. + + Returns: + True if this service would like to know about this room. + """ member_list = await store.get_users_in_room( room_id, on_invalidate=cache_context.invalidate ) @@ -203,12 +187,14 @@ class ApplicationService: return True return False - def _matches_room_id(self, event): + def _matches_room_id(self, event: EventBase) -> bool: if hasattr(event, "room_id"): return self.is_interested_in_room(event.room_id) return False - async def _matches_aliases(self, event, store): + async def _matches_aliases( + self, event: EventBase, store: Optional["DataStore"] = None + ) -> bool: if not store or not event: return False @@ -218,12 +204,15 @@ class ApplicationService: return True return False - async def is_interested(self, event, store=None) -> bool: + async def is_interested( + self, event: EventBase, store: Optional["DataStore"] = None + ) -> bool: """Check if this service is interested in this event. Args: - event(Event): The event to check. - store(DataStore) + event: The event to check. + store: The datastore to query. + Returns: True if this service would like to know about this event.
""" @@ -231,39 +220,66 @@ class ApplicationService: if self._matches_room_id(event): return True + # This will check the namespaces first before + # checking the store, so should be run before _matches_aliases + if await self._matches_user(event, store): + return True + + # This will check the store, so should be run last if await self._matches_aliases(event, store): return True - if await self._matches_user(event, store): + return False + + @cached(num_args=1) + async def is_interested_in_presence( + self, user_id: UserID, store: "DataStore" + ) -> bool: + """Check if this service is interested a user's presence + + Args: + user_id: The user to check. + store: The datastore to query. + + Returns: + True if this service would like to know about presence for this user. + """ + # Find all the rooms the sender is in + if self.is_interested_in_user(user_id.to_string()): return True + room_ids = await store.get_rooms_for_user(user_id.to_string()) + # Then find out if the appservice is interested in any of those rooms + for room_id in room_ids: + if await self.matches_user_in_member_list(room_id, store): + return True return False - def is_interested_in_user(self, user_id): + def is_interested_in_user(self, user_id: str) -> bool: return ( - self._matches_regex(user_id, ApplicationService.NS_USERS) + bool(self._matches_regex(user_id, ApplicationService.NS_USERS)) or user_id == self.sender ) - def is_interested_in_alias(self, alias): + def is_interested_in_alias(self, alias: str) -> bool: return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES)) - def is_interested_in_room(self, room_id): + def is_interested_in_room(self, room_id: str) -> bool: return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS)) - def is_exclusive_user(self, user_id): + def is_exclusive_user(self, user_id: str) -> bool: return ( self._is_exclusive(ApplicationService.NS_USERS, user_id) or user_id == self.sender ) - def is_interested_in_protocol(self, protocol): + def is_interested_in_protocol(self, protocol: str) -> bool: return protocol in self.protocols - def is_exclusive_alias(self, alias): + def is_exclusive_alias(self, alias: str) -> bool: return self._is_exclusive(ApplicationService.NS_ALIASES, alias) - def is_exclusive_room(self, room_id): + def is_exclusive_room(self, room_id: str) -> bool: return self._is_exclusive(ApplicationService.NS_ROOMS, room_id) def get_exclusive_user_regexes(self): @@ -276,14 +292,14 @@ class ApplicationService: if regex_obj["exclusive"] ] - def get_groups_for_user(self, user_id): + def get_groups_for_user(self, user_id: str) -> Iterable[str]: """Get the groups that this user is associated with by this AS Args: - user_id (str): The ID of the user. + user_id: The ID of the user. Returns: - iterable[str]: an iterable that yields group_id strings. + An iterable that yields group_id strings. 
""" return ( regex_obj["group_id"] @@ -291,7 +307,7 @@ class ApplicationService: if "group_id" in regex_obj and regex_obj["regex"].match(user_id) ) - def is_rate_limited(self): + def is_rate_limited(self) -> bool: return self.rate_limited def __str__(self): @@ -300,3 +316,45 @@ class ApplicationService: dict_copy["token"] = "<redacted>" dict_copy["hs_token"] = "<redacted>" return "ApplicationService: %s" % (dict_copy,) + + +class AppServiceTransaction: + """Represents an application service transaction.""" + + def __init__( + self, + service: ApplicationService, + id: int, + events: List[EventBase], + ephemeral: List[JsonDict], + ): + self.service = service + self.id = id + self.events = events + self.ephemeral = ephemeral + + async def send(self, as_api: "ApplicationServiceApi") -> bool: + """Sends this transaction using the provided AS API interface. + + Args: + as_api: The API to use to send. + Returns: + True if the transaction was sent. + """ + return await as_api.push_bulk( + service=self.service, + events=self.events, + ephemeral=self.ephemeral, + txn_id=self.id, + ) + + async def complete(self, store: "DataStore") -> None: + """Completes this transaction as successful. + + Marks this transaction ID on the application service and removes the + transaction contents from the database. + + Args: + store: The database store to operate on. + """ + await store.complete_appservice_txn(service=self.service, txn_id=self.id) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index c526c28b93..e366a982b8 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py
@@ -14,12 +14,13 @@ # limitations under the License. import logging import urllib -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, List, Optional, Tuple from prometheus_client import Counter from synapse.api.constants import EventTypes, ThirdPartyEntityKind from synapse.api.errors import CodeMessageException +from synapse.events import EventBase from synapse.events.utils import serialize_event from synapse.http.client import SimpleHttpClient from synapse.types import JsonDict, ThirdPartyInstanceID @@ -93,7 +94,7 @@ class ApplicationServiceApi(SimpleHttpClient): self.protocol_meta_cache = ResponseCache( hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS - ) + ) # type: ResponseCache[Tuple[str, str]] async def query_user(self, service, user_id): if service.url is None: @@ -201,7 +202,13 @@ class ApplicationServiceApi(SimpleHttpClient): key = (service.id, protocol) return await self.protocol_meta_cache.wrap(key, _get) - async def push_bulk(self, service, events, txn_id=None): + async def push_bulk( + self, + service: "ApplicationService", + events: List[EventBase], + ephemeral: List[JsonDict], + txn_id: Optional[int] = None, + ): if service.url is None: return True @@ -211,15 +218,19 @@ class ApplicationServiceApi(SimpleHttpClient): logger.warning( "push_bulk: Missing txn ID sending events to %s", service.url ) - txn_id = str(0) - txn_id = str(txn_id) + txn_id = 0 + + uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id))) + + # Never send ephemeral events to appservices that do not support it + if service.supports_ephemeral: + body = {"events": events, "de.sorunome.msc2409.ephemeral": ephemeral} + else: + body = {"events": events} - uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id)) try: await self.put_json( - uri=uri, - json_body={"events": events}, - args={"access_token": service.hs_token}, + uri=uri, json_body=body, args={"access_token": service.hs_token}, ) sent_transactions_counter.labels(service.id).inc() sent_events_counter.labels(service.id).inc(len(events)) diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
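The push_bulk change above sends ephemeral events only to appservices that registered support for them, under the unstable MSC2409 field. A rough sketch of the transaction URI and body it builds (the URL, events and transaction ID are placeholder values, not taken from the patch):

import urllib.parse

service_url = "http://localhost:9000"   # placeholder appservice URL
supports_ephemeral = True
events = [{"type": "m.room.message", "room_id": "!room:example.org"}]
ephemeral = [{"type": "m.typing", "room_id": "!room:example.org"}]
txn_id = 42

uri = service_url + ("/transactions/%s" % urllib.parse.quote(str(txn_id)))

# Ephemeral events are only included for appservices that opted in.
if supports_ephemeral:
    body = {"events": events, "de.sorunome.msc2409.ephemeral": ephemeral}
else:
    body = {"events": events}

print(uri)
print(body)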
index 8eb8c6f51c..58291afc22 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py
@@ -49,14 +49,24 @@ This is all tied together by the AppServiceScheduler which DIs the required components. """ import logging +from typing import List -from synapse.appservice import ApplicationServiceState +from synapse.appservice import ApplicationService, ApplicationServiceState +from synapse.events import EventBase from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.types import JsonDict logger = logging.getLogger(__name__) +# Maximum number of events to provide in an AS transaction. +MAX_PERSISTENT_EVENTS_PER_TRANSACTION = 100 + +# Maximum number of ephemeral events to provide in an AS transaction. +MAX_EPHEMERAL_EVENTS_PER_TRANSACTION = 100 + + class ApplicationServiceScheduler: """ Public facing API for this module. Does the required DI to tie the components together. This also serves as the "event_pool", which in this @@ -82,8 +92,13 @@ class ApplicationServiceScheduler: for service in services: self.txn_ctrl.start_recoverer(service) - def submit_event_for_as(self, service, event): - self.queuer.enqueue(service, event) + def submit_event_for_as(self, service: ApplicationService, event: EventBase): + self.queuer.enqueue_event(service, event) + + def submit_ephemeral_events_for_as( + self, service: ApplicationService, events: List[JsonDict] + ): + self.queuer.enqueue_ephemeral(service, events) class _ServiceQueuer: @@ -96,17 +111,15 @@ class _ServiceQueuer: def __init__(self, txn_ctrl, clock): self.queued_events = {} # dict of {service_id: [events]} + self.queued_ephemeral = {} # dict of {service_id: [events]} # the appservices which currently have a transaction in flight self.requests_in_flight = set() self.txn_ctrl = txn_ctrl self.clock = clock - def enqueue(self, service, event): - self.queued_events.setdefault(service.id, []).append(event) - + def _start_background_request(self, service): # start a sender for this appservice if we don't already have one - if service.id in self.requests_in_flight: return @@ -114,7 +127,15 @@ class _ServiceQueuer: "as-sender-%s" % (service.id,), self._send_request, service ) - async def _send_request(self, service): + def enqueue_event(self, service: ApplicationService, event: EventBase): + self.queued_events.setdefault(service.id, []).append(event) + self._start_background_request(service) + + def enqueue_ephemeral(self, service: ApplicationService, events: List[JsonDict]): + self.queued_ephemeral.setdefault(service.id, []).extend(events) + self._start_background_request(service) + + async def _send_request(self, service: ApplicationService): # sanity-check: we shouldn't get here if this service already has a sender # running. 
assert service.id not in self.requests_in_flight @@ -122,11 +143,19 @@ class _ServiceQueuer: self.requests_in_flight.add(service.id) try: while True: - events = self.queued_events.pop(service.id, []) - if not events: + all_events = self.queued_events.get(service.id, []) + events = all_events[:MAX_PERSISTENT_EVENTS_PER_TRANSACTION] + del all_events[:MAX_PERSISTENT_EVENTS_PER_TRANSACTION] + + all_events_ephemeral = self.queued_ephemeral.get(service.id, []) + ephemeral = all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION] + del all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION] + + if not events and not ephemeral: return + try: - await self.txn_ctrl.send(service, events) + await self.txn_ctrl.send(service, events, ephemeral) except Exception: logger.exception("AS request failed") finally: @@ -158,9 +187,16 @@ class _TransactionController: # for UTs self.RECOVERER_CLASS = _Recoverer - async def send(self, service, events): + async def send( + self, + service: ApplicationService, + events: List[EventBase], + ephemeral: List[JsonDict] = [], + ): try: - txn = await self.store.create_appservice_txn(service=service, events=events) + txn = await self.store.create_appservice_txn( + service=service, events=events, ephemeral=ephemeral + ) service_is_up = await self._is_service_up(service) if service_is_up: sent = await txn.send(self.as_api) @@ -204,7 +240,7 @@ class _TransactionController: recoverer.recover() logger.info("Now %i active recoverers", len(self.recoverers)) - async def _is_service_up(self, service): + async def _is_service_up(self, service: ApplicationService) -> bool: state = await self.store.get_appservice_state(service) return state == ApplicationServiceState.UP or state is None diff --git a/synapse/config/_base.py b/synapse/config/_base.py
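With the scheduler change above, each pass now drains at most MAX_PERSISTENT_EVENTS_PER_TRANSACTION persistent and MAX_EPHEMERAL_EVENTS_PER_TRANSACTION ephemeral events instead of popping the whole queue. The slicing pattern, shown on a plain list (the queue contents are dummies):

MAX_PERSISTENT_EVENTS_PER_TRANSACTION = 100

queued_events = {"as1": list(range(250))}  # dummy stand-in for queued events

def next_batch(service_id: str) -> list:
    # Take up to the per-transaction cap and leave the rest queued.
    all_events = queued_events.get(service_id, [])
    batch = all_events[:MAX_PERSISTENT_EVENTS_PER_TRANSACTION]
    del all_events[:MAX_PERSISTENT_EVENTS_PER_TRANSACTION]
    return batch

print(len(next_batch("as1")))  # 100
print(len(next_batch("as1")))  # 100
print(len(next_batch("as1")))  # 50
print(len(next_batch("as1")))  # 0, so the send loop would return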
index 85f65da4d9..2931a88207 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py
@@ -23,7 +23,7 @@ import urllib.parse from collections import OrderedDict from hashlib import sha256 from textwrap import dedent -from typing import Any, Callable, List, MutableMapping, Optional +from typing import Any, Callable, Iterable, List, MutableMapping, Optional import attr import jinja2 @@ -32,7 +32,17 @@ import yaml class ConfigError(Exception): - pass + """Represents a problem parsing the configuration + + Args: + msg: A textual description of the error. + path: Where appropriate, an indication of where in the configuration + the problem lies. + """ + + def __init__(self, msg: str, path: Optional[Iterable[str]] = None): + self.msg = msg + self.path = path # We split these messages out to allow packages to override with package diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index b8faafa9bd..29aa064e57 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi
@@ -1,8 +1,9 @@ -from typing import Any, List, Optional +from typing import Any, Iterable, List, Optional from synapse.config import ( api, appservice, + auth, captcha, cas, consent_config, @@ -14,7 +15,6 @@ from synapse.config import ( logger, metrics, oidc_config, - password, password_auth_providers, push, ratelimiting, @@ -35,7 +35,10 @@ from synapse.config import ( workers, ) -class ConfigError(Exception): ... +class ConfigError(Exception): + def __init__(self, msg: str, path: Optional[Iterable[str]] = None): + self.msg = msg + self.path = path MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str MISSING_REPORT_STATS_SPIEL: str @@ -62,7 +65,7 @@ class RootConfig: sso: sso.SSOConfig oidc: oidc_config.OIDCConfig jwt: jwt_config.JWTConfig - password: password.PasswordConfig + auth: auth.AuthConfig email: emailconfig.EmailConfig worker: workers.WorkerConfig authproviders: password_auth_providers.PasswordAuthProviderConfig diff --git a/synapse/config/_util.py b/synapse/config/_util.py
index c74969a977..1bbe83c317 100644 --- a/synapse/config/_util.py +++ b/synapse/config/_util.py
@@ -38,14 +38,27 @@ def validate_config( try: jsonschema.validate(config, json_schema) except jsonschema.ValidationError as e: - # copy `config_path` before modifying it. - path = list(config_path) - for p in list(e.path): - if isinstance(p, int): - path.append("<item %i>" % p) - else: - path.append(str(p)) - - raise ConfigError( - "Unable to parse configuration: %s at %s" % (e.message, ".".join(path)) - ) + raise json_error_to_config_error(e, config_path) + + +def json_error_to_config_error( + e: jsonschema.ValidationError, config_path: Iterable[str] +) -> ConfigError: + """Converts a json validation error to a user-readable ConfigError + + Args: + e: the exception to be converted + config_path: the path within the config file. This will be used as a basis + for the error message. + + Returns: + a ConfigError + """ + # copy `config_path` before modifying it. + path = list(config_path) + for p in list(e.path): + if isinstance(p, int): + path.append("<item %i>" % p) + else: + path.append(str(p)) + return ConfigError(e.message, path) diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
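json_error_to_config_error above flattens a jsonschema ValidationError into a dotted path for the new ConfigError. A rough illustration of that flattening against a made-up schema (the schema and config values are invented for the example; only jsonschema's documented ValidationError.path and .message attributes are relied on):

import jsonschema

schema = {
    "type": "object",
    "properties": {
        "federation_metrics_domains": {"type": "array", "items": {"type": "string"}}
    },
}
config = {"federation_metrics_domains": ["matrix.org", 123]}

try:
    jsonschema.validate(config, schema)
except jsonschema.ValidationError as e:
    path = ["federation"]  # base path within the config file
    for p in list(e.path):
        path.append("<item %i>" % p if isinstance(p, int) else str(p))
    # e.g. "123 is not of type 'string' at federation.federation_metrics_domains.<item 1>"
    print("%s at %s" % (e.message, ".".join(path)))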
index 8ed3e24258..746fc3cc02 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py
@@ -160,6 +160,8 @@ def _load_appservice(hostname, as_info, config_filename): if as_info.get("ip_range_whitelist"): ip_range_whitelist = IPSet(as_info.get("ip_range_whitelist")) + supports_ephemeral = as_info.get("de.sorunome.msc2409.push_ephemeral", False) + return ApplicationService( token=as_info["as_token"], hostname=hostname, @@ -168,6 +170,7 @@ def _load_appservice(hostname, as_info, config_filename): hs_token=as_info["hs_token"], sender=user_id, id=as_info["id"], + supports_ephemeral=supports_ephemeral, protocols=protocols, rate_limited=rate_limited, ip_range_whitelist=ip_range_whitelist, diff --git a/synapse/config/password.py b/synapse/config/auth.py
index 9c0ea8c30a..2b3e2ce87b 100644 --- a/synapse/config/password.py +++ b/synapse/config/auth.py
@@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,11 +17,11 @@ from ._base import Config -class PasswordConfig(Config): - """Password login configuration +class AuthConfig(Config): + """Password and login configuration """ - section = "password" + section = "auth" def read_config(self, config, **kwargs): password_config = config.get("password_config", {}) @@ -35,6 +36,10 @@ class PasswordConfig(Config): self.password_policy = password_config.get("policy") or {} self.password_policy_enabled = self.password_policy.get("enabled", False) + # User-interactive authentication + ui_auth = config.get("ui_auth") or {} + self.ui_auth_session_timeout = ui_auth.get("session_timeout", 0) + def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ password_config: @@ -87,4 +92,19 @@ class PasswordConfig(Config): # Defaults to 'false'. # #require_uppercase: true + + ui_auth: + # The number of milliseconds to allow a user-interactive authentication + # session to be active. + # + # This defaults to 0, meaning the user is queried for their credentials + # before every action, but this can be overridden to allow a single + # validation to be re-used. This weakens the protections afforded by + # the user-interactive authentication process, by allowing for multiple + # (and potentially different) operations to use the same validation session. + # + # Uncomment below to allow for credential validation to last for 15 + # seconds. + # + #session_timeout: 15000 """ diff --git a/synapse/config/cas.py b/synapse/config/cas.py
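The renamed AuthConfig above now also reads an optional ui_auth.session_timeout (in milliseconds). A small sketch of how those keys come out of a parsed homeserver.yaml, assuming PyYAML and an inline document rather than a real config file:

import yaml

config = yaml.safe_load(
    """
password_config:
  enabled: true
ui_auth:
  session_timeout: 15000
"""
)

password_config = config.get("password_config") or {}
ui_auth = config.get("ui_auth") or {}

password_enabled = password_config.get("enabled", True)
# 0 means credentials are re-validated before every sensitive operation.
ui_auth_session_timeout = ui_auth.get("session_timeout", 0)

print(password_enabled, ui_auth_session_timeout)  # True 15000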
index 4526c1a67b..2f97e6d258 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py
@@ -26,14 +26,14 @@ class CasConfig(Config): def read_config(self, config, **kwargs): cas_config = config.get("cas_config", None) - if cas_config: - self.cas_enabled = cas_config.get("enabled", True) + self.cas_enabled = cas_config and cas_config.get("enabled", True) + + if self.cas_enabled: self.cas_server_url = cas_config["server_url"] self.cas_service_url = cas_config["service_url"] self.cas_displayname_attribute = cas_config.get("displayname_attribute") - self.cas_required_attributes = cas_config.get("required_attributes", {}) + self.cas_required_attributes = cas_config.get("required_attributes") or {} else: - self.cas_enabled = False self.cas_server_url = None self.cas_service_url = None self.cas_displayname_attribute = None @@ -41,13 +41,35 @@ class CasConfig(Config): def generate_config_section(self, config_dir_path, server_name, **kwargs): return """ - # Enable CAS for registration and login. + # Enable Central Authentication Service (CAS) for registration and login. # - #cas_config: - # enabled: true - # server_url: "https://cas-server.com" - # service_url: "https://homeserver.domain.com:8448" - # #displayname_attribute: name - # #required_attributes: - # # name: value + cas_config: + # Uncomment the following to enable authorization against a CAS server. + # Defaults to false. + # + #enabled: true + + # The URL of the CAS authorization endpoint. + # + #server_url: "https://cas-server.com" + + # The public URL of the homeserver. + # + #service_url: "https://homeserver.domain.com:8448" + + # The attribute of the CAS response to use as the display name. + # + # If unset, no displayname will be set. + # + #displayname_attribute: name + + # It is possible to configure Synapse to only allow logins if CAS attributes + # match particular values. All of the keys in the mapping below must exist + # and the values must match the given value. Alternately if the given value + # is None then any value is allowed (the attribute just must exist). + # All of the listed attributes must match for the login to be permitted. + # + #required_attributes: + # userGroup: "staff" + # department: None """ diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index cceffbfee2..d4328c46b9 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py
@@ -322,6 +322,22 @@ class EmailConfig(Config): self.email_subjects = EmailSubjectConfig(**subjects) + # The invite client location should be a HTTP(S) URL or None. + self.invite_client_location = email_config.get("invite_client_location") or None + if self.invite_client_location: + if not isinstance(self.invite_client_location, str): + raise ConfigError( + "Config option email.invite_client_location must be type str" + ) + if not ( + self.invite_client_location.startswith("http://") + or self.invite_client_location.startswith("https://") + ): + raise ConfigError( + "Config option email.invite_client_location must be a http or https URL", + path=("email", "invite_client_location"), + ) + def generate_config_section(self, config_dir_path, server_name, **kwargs): return ( """\ @@ -389,10 +405,15 @@ class EmailConfig(Config): # #validation_token_lifetime: 15m - # Directory in which Synapse will try to find the template files below. - # If not set, default templates from within the Synapse package will be used. + # The web client location to direct users to during an invite. This is passed + # to the identity server as the org.matrix.web_client_location key. Defaults + # to unset, giving no guidance to the identity server. # - # Do not uncomment this setting unless you want to customise the templates. + #invite_client_location: https://app.element.io + + # Directory in which Synapse will try to find the template files below. + # If not set, or the files named below are not found within the template + # directory, default templates from within the Synapse package will be used. # # Synapse will look for the following templates in this directory: # diff --git a/synapse/config/federation.py b/synapse/config/federation.py
index ffd8fca54e..9f3c57e6a1 100644 --- a/synapse/config/federation.py +++ b/synapse/config/federation.py
@@ -12,12 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from typing import Optional -from netaddr import IPSet - -from synapse.config._base import Config, ConfigError +from synapse.config._base import Config from synapse.config._util import validate_config @@ -36,23 +33,6 @@ class FederationConfig(Config): for domain in federation_domain_whitelist: self.federation_domain_whitelist[domain] = True - self.federation_ip_range_blacklist = config.get( - "federation_ip_range_blacklist", [] - ) - - # Attempt to create an IPSet from the given ranges - try: - self.federation_ip_range_blacklist = IPSet( - self.federation_ip_range_blacklist - ) - - # Always blacklist 0.0.0.0, :: - self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) - except Exception as e: - raise ConfigError( - "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e - ) - federation_metrics_domains = config.get("federation_metrics_domains") or [] validate_config( _METRICS_FOR_DOMAINS_SCHEMA, @@ -76,27 +56,6 @@ class FederationConfig(Config): # - nyc.example.com # - syd.example.com - # Prevent federation requests from being sent to the following - # blacklist IP address CIDR ranges. If this option is not specified, or - # specified with an empty list, no ip range blacklist will be enforced. - # - # As of Synapse v1.4.0 this option also affects any outbound requests to identity - # servers provided by user input. - # - # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly - # listed here, since they correspond to unroutable addresses.) - # - federation_ip_range_blacklist: - - '127.0.0.0/8' - - '10.0.0.0/8' - - '172.16.0.0/12' - - '192.168.0.0/16' - - '100.64.0.0/10' - - '169.254.0.0/16' - - '::1/128' - - 'fe80::/64' - - 'fc00::/7' - # Report prometheus metrics on the age of PDUs being sent to and received from # the following domains. This can be used to give an idea of "delay" on inbound # and outbound federation, though be aware that any delay can be due to problems diff --git a/synapse/config/groups.py b/synapse/config/groups.py
index d6862d9a64..7b7860ea71 100644 --- a/synapse/config/groups.py +++ b/synapse/config/groups.py
@@ -32,5 +32,5 @@ class GroupsConfig(Config): # If enabled, non server admins can only create groups with local parts # starting with this prefix # - #group_creation_prefix: "unofficial/" + #group_creation_prefix: "unofficial_" """ diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index be65554524..4bd2b3587b 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py
@@ -17,6 +17,7 @@ from ._base import RootConfig from .api import ApiConfig from .appservice import AppServiceConfig +from .auth import AuthConfig from .cache import CacheConfig from .captcha import CaptchaConfig from .cas import CasConfig @@ -30,7 +31,6 @@ from .key import KeyConfig from .logger import LoggingConfig from .metrics import MetricsConfig from .oidc_config import OIDCConfig -from .password import PasswordConfig from .password_auth_providers import PasswordAuthProviderConfig from .push import PushConfig from .ratelimiting import RatelimitConfig @@ -76,7 +76,7 @@ class HomeServerConfig(RootConfig): CasConfig, SSOConfig, JWTConfig, - PasswordConfig, + AuthConfig, EmailConfig, PasswordAuthProviderConfig, PushConfig, diff --git a/synapse/config/jwt_config.py b/synapse/config/jwt_config.py
index 3252ad9e7f..f30330abb6 100644 --- a/synapse/config/jwt_config.py +++ b/synapse/config/jwt_config.py
@@ -63,7 +63,7 @@ class JWTConfig(Config): # and issued at ("iat") claims are validated if present. # # Note that this is a non-standard login type and client support is - # expected to be non-existant. + # expected to be non-existent. # # See https://github.com/matrix-org/synapse/blob/master/docs/jwt.md. # diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 13d6f6a3ea..4df3f93c1c 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py
@@ -23,7 +23,6 @@ from string import Template import yaml from twisted.logger import ( - ILogObserver, LogBeginner, STDLibLogObserver, eventAsText, @@ -32,11 +31,9 @@ from twisted.logger import ( import synapse from synapse.app import _base as appbase -from synapse.logging._structured import ( - reload_structured_logging, - setup_structured_logging, -) +from synapse.logging._structured import setup_structured_logging from synapse.logging.context import LoggingContextFilter +from synapse.logging.filter import MetadataFilter from synapse.util.versionstring import get_version_string from ._base import Config, ConfigError @@ -48,7 +45,11 @@ DEFAULT_LOG_CONFIG = Template( # This is a YAML file containing a standard Python logging configuration # dictionary. See [1] for details on the valid settings. # +# Synapse also supports structured logging for machine readable logs which can +# be ingested by ELK stacks. See [2] for details. +# # [1]: https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema +# [2]: https://github.com/matrix-org/synapse/blob/master/docs/structured_logging.md version: 1 @@ -105,7 +106,7 @@ root: # then write them to a file. # # Replace "buffer" with "console" to log to stderr instead. (Note that you'll - # also need to update the configuation for the `twisted` logger above, in + # also need to update the configuration for the `twisted` logger above, in # this case.) # handlers: [buffer] @@ -176,11 +177,11 @@ class LoggingConfig(Config): log_config_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=log_file)) -def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner): +def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) -> None: """ - Set up Python stdlib logging. + Set up Python standard library logging. """ - if log_config is None: + if log_config_path is None: log_format = ( "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s" " - %(message)s" @@ -196,7 +197,8 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner): handler.setFormatter(formatter) logger.addHandler(handler) else: - logging.config.dictConfig(log_config) + # Load the logging configuration. + _load_logging_config(log_config_path) # We add a log record factory that runs all messages through the # LoggingContextFilter so that we get the context *at the time we log* @@ -204,12 +206,14 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner): # filter options, but care must when using e.g. MemoryHandler to buffer # writes. - log_filter = LoggingContextFilter(request="") + log_context_filter = LoggingContextFilter() + log_metadata_filter = MetadataFilter({"server_name": config.server_name}) old_factory = logging.getLogRecordFactory() def factory(*args, **kwargs): record = old_factory(*args, **kwargs) - log_filter.filter(record) + log_context_filter.filter(record) + log_metadata_filter.filter(record) return record logging.setLogRecordFactory(factory) @@ -255,21 +259,40 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner): if not config.no_redirect_stdio: print("Redirected stdout/stderr to logs") - return observer - -def _reload_stdlib_logging(*args, log_config=None): - logger = logging.getLogger("") +def _load_logging_config(log_config_path: str) -> None: + """ + Configure logging from a log config path. 
+ """ + with open(log_config_path, "rb") as f: + log_config = yaml.safe_load(f.read()) if not log_config: - logger.warning("Reloaded a blank config?") + logging.warning("Loaded a blank logging config?") + + # If the old structured logging configuration is being used, convert it to + # the new style configuration. + if "structured" in log_config and log_config.get("structured"): + log_config = setup_structured_logging(log_config) logging.config.dictConfig(log_config) +def _reload_logging_config(log_config_path): + """ + Reload the log configuration from the file and apply it. + """ + # If no log config path was given, it cannot be reloaded. + if log_config_path is None: + return + + _load_logging_config(log_config_path) + logging.info("Reloaded log config from %s due to SIGHUP", log_config_path) + + def setup_logging( hs, config, use_worker_options=False, logBeginner: LogBeginner = globalLogBeginner -) -> ILogObserver: +) -> None: """ Set up the logging subsystem. @@ -282,41 +305,18 @@ def setup_logging( logBeginner: The Twisted logBeginner to use. - Returns: - The "root" Twisted Logger observer, suitable for sending logs to from a - Logger instance. """ - log_config = config.worker_log_config if use_worker_options else config.log_config - - def read_config(*args, callback=None): - if log_config is None: - return None - - with open(log_config, "rb") as f: - log_config_body = yaml.safe_load(f.read()) - - if callback: - callback(log_config=log_config_body) - logging.info("Reloaded log config from %s due to SIGHUP", log_config) - - return log_config_body + log_config_path = ( + config.worker_log_config if use_worker_options else config.log_config + ) - log_config_body = read_config() + # Perform one-time logging configuration. + _setup_stdlib_logging(config, log_config_path, logBeginner=logBeginner) + # Add a SIGHUP handler to reload the logging configuration, if one is available. + appbase.register_sighup(_reload_logging_config, log_config_path) - if log_config_body and log_config_body.get("structured") is True: - logger = setup_structured_logging( - hs, config, log_config_body, logBeginner=logBeginner - ) - appbase.register_sighup(read_config, callback=reload_structured_logging) - else: - logger = _setup_stdlib_logging(config, log_config_body, logBeginner=logBeginner) - appbase.register_sighup(read_config, callback=_reload_stdlib_logging) - - # make sure that the first thing we log is a thing we can grep backwards - # for + # Log immediately so we can grep backwards. logging.warning("***** STARTING SERVER *****") logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse)) logging.info("Server hostname: %s", config.server_name) logging.info("Instance name: %s", hs.get_instance_name()) - - return logger diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index f924116819..4e3055282d 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py
@@ -56,6 +56,7 @@ class OIDCConfig(Config): self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint") self.oidc_jwks_uri = oidc_config.get("jwks_uri") self.oidc_skip_verification = oidc_config.get("skip_verification", False) + self.oidc_user_profile_method = oidc_config.get("user_profile_method", "auto") self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False) ump_config = oidc_config.get("user_mapping_provider", {}) @@ -65,7 +66,7 @@ class OIDCConfig(Config): ( self.oidc_user_mapping_provider_class, self.oidc_user_mapping_provider_config, - ) = load_module(ump_config) + ) = load_module(ump_config, ("oidc_config", "user_mapping_provider")) # Ensure loaded user mapping module has defined all necessary methods required_methods = [ @@ -86,11 +87,10 @@ class OIDCConfig(Config): def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ - # OpenID Connect integration. The following settings can be used to make Synapse - # use an OpenID Connect Provider for authentication, instead of its internal - # password database. + # Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login. # - # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md. + # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md + # for some example configurations. # oidc_config: # Uncomment the following to enable authorization against an OpenID Connect @@ -159,6 +159,14 @@ class OIDCConfig(Config): # #skip_verification: true + # Whether to fetch the user profile from the userinfo endpoint. Valid + # values are: "auto" or "userinfo_endpoint". + # + # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included + # in `scopes`. Uncomment the following to always fetch the userinfo endpoint. + # + #user_profile_method: "userinfo_endpoint" + # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead # of failing. This could be used if switching from password logins to OIDC. Defaults to false. # @@ -195,9 +203,10 @@ class OIDCConfig(Config): # * user: The claims returned by the UserInfo Endpoint and/or in the ID # Token # - # This must be configured if using the default mapping provider. + # If this is not set, the user will be prompted to choose their + # own username. # - localpart_template: "{{{{ user.preferred_username }}}}" + #localpart_template: "{{{{ user.preferred_username }}}}" # Jinja2 template for the display name to set on first login. # diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py
index 4fda8ae987..85d07c4f8f 100644 --- a/synapse/config/password_auth_providers.py +++ b/synapse/config/password_auth_providers.py
@@ -36,7 +36,7 @@ class PasswordAuthProviderConfig(Config): providers.append({"module": LDAP_PROVIDER, "config": ldap_config}) providers.extend(config.get("password_providers") or []) - for provider in providers: + for i, provider in enumerate(providers): mod_name = provider["module"] # This is for backwards compat when the ldap auth provider resided @@ -45,7 +45,8 @@ class PasswordAuthProviderConfig(Config): mod_name = LDAP_PROVIDER (provider_class, provider_config) = load_module( - {"module": mod_name, "config": provider["config"]} + {"module": mod_name, "config": provider["config"]}, + ("password_providers", "<item %i>" % i), ) self.password_providers.append((provider_class, provider_config)) diff --git a/synapse/config/push.py b/synapse/config/push.py
index a1f3752c8a..3adbfb73e6 100644 --- a/synapse/config/push.py +++ b/synapse/config/push.py
@@ -21,8 +21,11 @@ class PushConfig(Config): section = "push" def read_config(self, config, **kwargs): - push_config = config.get("push", {}) + push_config = config.get("push") or {} self.push_include_content = push_config.get("include_content", True) + self.push_group_unread_count_by_room = push_config.get( + "group_unread_count_by_room", True + ) pusher_instances = config.get("pusher_instances") or [] self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances) @@ -49,18 +52,33 @@ class PushConfig(Config): def generate_config_section(self, config_dir_path, server_name, **kwargs): return """ - # Clients requesting push notifications can either have the body of - # the message sent in the notification poke along with other details - # like the sender, or just the event ID and room ID (`event_id_only`). - # If clients choose the former, this option controls whether the - # notification request includes the content of the event (other details - # like the sender are still included). For `event_id_only` push, it - # has no effect. - # - # For modern android devices the notification content will still appear - # because it is loaded by the app. iPhone, however will send a - # notification saying only that a message arrived and who it came from. - # - #push: - # include_content: true + ## Push ## + + push: + # Clients requesting push notifications can either have the body of + # the message sent in the notification poke along with other details + # like the sender, or just the event ID and room ID (`event_id_only`). + # If clients choose the former, this option controls whether the + # notification request includes the content of the event (other details + # like the sender are still included). For `event_id_only` push, it + # has no effect. + # + # For modern android devices the notification content will still appear + # because it is loaded by the app. iPhone, however will send a + # notification saying only that a message arrived and who it came from. + # + # The default value is "true" to include message details. Uncomment to only + # include the event ID and room ID in push notification payloads. + # + #include_content: false + + # When a push notification is received, an unread count is also sent. + # This number can either be calculated as the number of unread messages + # for the user, or the number of *rooms* the user has unread messages in. + # + # The default value is "true", meaning push clients will see the number of + # rooms with unread messages in them. Uncomment to instead send the number + # of unread messages. + # + #group_unread_count_by_room: false """ diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index d7e3690a32..cc5f75123c 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py
@@ -143,7 +143,7 @@ class RegistrationConfig(Config): RoomCreationPreset.TRUSTED_PRIVATE_CHAT, } - # Pull the creater/inviter from the configuration, this gets used to + # Pull the creator/inviter from the configuration, this gets used to # send invites for invite-only rooms. mxid_localpart = config.get("auto_join_mxid_localpart") self.auto_join_user_id = None @@ -347,8 +347,9 @@ class RegistrationConfig(Config): # email will be globally disabled. # # Additionally, if `msisdn` is not set, registration and password resets via msisdn - # will be disabled regardless. This is due to Synapse currently not supporting any - # method of sending SMS messages on its own. + # will be disabled regardless, and users will not be able to associate an msisdn + # identifier to their account. This is due to Synapse currently not supporting + # any method of sending SMS messages on its own. # # To enable using an identity server for operations regarding a particular third-party # identifier type, set the value to the URL of that identity server as shown in the diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 01009f3924..850ac3ebd6 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py
@@ -17,6 +17,9 @@ import os from collections import namedtuple from typing import Dict, List +from netaddr import IPSet + +from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST from synapse.python_dependencies import DependencyException, check_requirements from synapse.util.module_loader import load_module @@ -100,7 +103,7 @@ class ContentRepositoryConfig(Config): "media_instance_running_background_jobs", ) - self.max_upload_size = self.parse_size(config.get("max_upload_size", "10M")) + self.max_upload_size = self.parse_size(config.get("max_upload_size", "50M")) self.max_image_pixels = self.parse_size(config.get("max_image_pixels", "32M")) self.max_spider_size = self.parse_size(config.get("max_spider_size", "10M")) @@ -142,7 +145,7 @@ class ContentRepositoryConfig(Config): # them to be started. self.media_storage_providers = [] # type: List[tuple] - for provider_config in storage_providers: + for i, provider_config in enumerate(storage_providers): # We special case the module "file_system" so as not to need to # expose FileStorageProviderBackend if provider_config["module"] == "file_system": @@ -151,7 +154,9 @@ class ContentRepositoryConfig(Config): ".FileStorageProviderBackend" ) - provider_class, parsed_config = load_module(provider_config) + provider_class, parsed_config = load_module( + provider_config, ("media_storage_providers", "<item %i>" % i) + ) wrapper_config = MediaStorageProviderConfig( provider_config.get("store_local", False), @@ -182,9 +187,6 @@ class ContentRepositoryConfig(Config): "to work" ) - # netaddr is a dependency for url_preview - from netaddr import IPSet - self.url_preview_ip_range_blacklist = IPSet( config["url_preview_ip_range_blacklist"] ) @@ -213,6 +215,10 @@ class ContentRepositoryConfig(Config): # strip final NL formatted_thumbnail_sizes = formatted_thumbnail_sizes[:-1] + ip_range_blacklist = "\n".join( + " # - '%s'" % ip for ip in DEFAULT_IP_RANGE_BLACKLIST + ) + return ( r""" ## Media Store ## @@ -242,7 +248,7 @@ class ContentRepositoryConfig(Config): # The largest allowed upload size in bytes # - #max_upload_size: 10M + #max_upload_size: 50M # Maximum number of pixels that will be thumbnailed # @@ -283,15 +289,7 @@ class ContentRepositoryConfig(Config): # you uncomment the following list as a starting point. # #url_preview_ip_range_blacklist: - # - '127.0.0.0/8' - # - '10.0.0.0/8' - # - '172.16.0.0/12' - # - '192.168.0.0/16' - # - '100.64.0.0/10' - # - '169.254.0.0/16' - # - '::1/128' - # - 'fe80::/64' - # - 'fc00::/7' +%(ip_range_blacklist)s # List of IP address CIDR ranges that the URL preview spider is allowed # to access even if they are specified in url_preview_ip_range_blacklist. diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py
index 6de1f9d103..9a3e1c3e7d 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py
@@ -99,7 +99,7 @@ class RoomDirectoryConfig(Config): # # Options for the rules include: # - # user_id: Matches agaisnt the creator of the alias + # user_id: Matches against the creator of the alias # room_id: Matches against the room ID being published # alias: Matches against any current local or canonical aliases # associated with the room @@ -180,7 +180,7 @@ class _RoomDirectoryRule: self._alias_regex = glob_to_regex(alias) self._room_id_regex = glob_to_regex(room_id) except Exception as e: - raise ConfigError("Failed to parse glob into regex: %s", e) + raise ConfigError("Failed to parse glob into regex") from e def matches(self, user_id, room_id, aliases): """Tests if this rule matches the given user_id, room_id and aliases. diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index 99aa8b3bf1..7b97d4f114 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py
@@ -90,6 +90,8 @@ class SAML2Config(Config): "grandfathered_mxid_source_attribute", "uid" ) + self.saml2_idp_entityid = saml2_config.get("idp_entityid", None) + # user_mapping_provider may be None if the key is present but has no value ump_dict = saml2_config.get("user_mapping_provider") or {} @@ -123,7 +125,7 @@ class SAML2Config(Config): ( self.saml2_user_mapping_provider_class, self.saml2_user_mapping_provider_config, - ) = load_module(ump_dict) + ) = load_module(ump_dict, ("saml2_config", "user_mapping_provider")) # Ensure loaded user mapping module has defined all necessary methods # Note parse_config() is already checked during the call to load_module @@ -216,10 +218,8 @@ class SAML2Config(Config): return """\ ## Single sign-on integration ## - # Enable SAML2 for registration and login. Uses pysaml2. - # - # At least one of `sp_config` or `config_path` must be set in this section to - # enable SAML login. + # The following settings can be used to make Synapse use a single sign-on + # provider for authentication, instead of its internal password database. # # You will probably also want to set the following options to `false` to # disable the regular login/registration flows: @@ -228,6 +228,11 @@ class SAML2Config(Config): # # You will also want to investigate the settings under the "sso" configuration # section below. + + # Enable SAML2 for registration and login. Uses pysaml2. + # + # At least one of `sp_config` or `config_path` must be set in this section to + # enable SAML login. # # Once SAML support is enabled, a metadata file will be exposed at # https://<server>:<port>/_matrix/saml2/metadata.xml, which you may be able to @@ -243,40 +248,70 @@ class SAML2Config(Config): # so it is not normally necessary to specify them unless you need to # override them. # - #sp_config: - # # point this to the IdP's metadata. You can use either a local file or - # # (preferably) a URL. - # metadata: - # #local: ["saml2/idp.xml"] - # remote: - # - url: https://our_idp/metadata.xml - # - # # By default, the user has to go to our login page first. If you'd like - # # to allow IdP-initiated login, set 'allow_unsolicited: true' in a - # # 'service.sp' section: - # # - # #service: - # # sp: - # # allow_unsolicited: true - # - # # The examples below are just used to generate our metadata xml, and you - # # may well not need them, depending on your setup. Alternatively you - # # may need a whole lot more detail - see the pysaml2 docs! - # - # description: ["My awesome SP", "en"] - # name: ["Test SP", "en"] - # - # organization: - # name: Example com - # display_name: - # - ["Example co", "en"] - # url: "http://example.com" - # - # contact_person: - # - given_name: Bob - # sur_name: "the Sysadmin" - # email_address": ["admin@example.com"] - # contact_type": technical + sp_config: + # Point this to the IdP's metadata. You must provide either a local + # file via the `local` attribute or (preferably) a URL via the + # `remote` attribute. + # + #metadata: + # local: ["saml2/idp.xml"] + # remote: + # - url: https://our_idp/metadata.xml + + # Allowed clock difference in seconds between the homeserver and IdP. + # + # Uncomment the below to increase the accepted time difference from 0 to 3 seconds. + # + #accepted_time_diff: 3 + + # By default, the user has to go to our login page first. 
If you'd like + # to allow IdP-initiated login, set 'allow_unsolicited: true' in a + # 'service.sp' section: + # + #service: + # sp: + # allow_unsolicited: true + + # The examples below are just used to generate our metadata xml, and you + # may well not need them, depending on your setup. Alternatively you + # may need a whole lot more detail - see the pysaml2 docs! + + #description: ["My awesome SP", "en"] + #name: ["Test SP", "en"] + + #ui_info: + # display_name: + # - lang: en + # text: "Display Name is the descriptive name of your service." + # description: + # - lang: en + # text: "Description should be a short paragraph explaining the purpose of the service." + # information_url: + # - lang: en + # text: "https://example.com/terms-of-service" + # privacy_statement_url: + # - lang: en + # text: "https://example.com/privacy-policy" + # keywords: + # - lang: en + # text: ["Matrix", "Element"] + # logo: + # - lang: en + # text: "https://example.com/logo.svg" + # width: "200" + # height: "80" + + #organization: + # name: Example com + # display_name: + # - ["Example co", "en"] + # url: "http://example.com" + + #contact_person: + # - given_name: Bob + # sur_name: "the Sysadmin" + # email_address": ["admin@example.com"] + # contact_type": technical # Instead of putting the config inline as above, you can specify a # separate pysaml2 configuration file: @@ -350,6 +385,14 @@ class SAML2Config(Config): # value: "staff" # - attribute: department # value: "sales" + + # If the metadata XML contains multiple IdP entities then the `idp_entityid` + # option must be set to the entity to redirect users to. + # + # Most deployments only have a single IdP entity and so should omit this + # option. + # + #idp_entityid: 'https://our_idp/entityid' """ % { "config_dir_path": config_dir_path } diff --git a/synapse/config/server.py b/synapse/config/server.py
index ef6d70e3f8..7242a4aa8e 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py
@@ -23,6 +23,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set import attr import yaml +from netaddr import IPSet from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.endpoint import parse_and_validate_server_name @@ -39,7 +40,35 @@ logger = logging.Logger(__name__) # in the list. DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"] -DEFAULT_ROOM_VERSION = "5" +DEFAULT_IP_RANGE_BLACKLIST = [ + # Localhost + "127.0.0.0/8", + # Private networks. + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + # Carrier grade NAT. + "100.64.0.0/10", + # Address registry. + "192.0.0.0/24", + # Link-local networks. + "169.254.0.0/16", + # Testing networks. + "198.18.0.0/15", + "192.0.2.0/24", + "198.51.100.0/24", + "203.0.113.0/24", + # Multicast. + "224.0.0.0/4", + # Localhost + "::1/128", + # Link-local addresses. + "fe80::/10", + # Unique local addresses. + "fc00::/7", +] + +DEFAULT_ROOM_VERSION = "6" ROOM_COMPLEXITY_TOO_GREAT = ( "Your homeserver is unable to join rooms this large or complex. " @@ -256,6 +285,38 @@ class ServerConfig(Config): # due to resource constraints self.admin_contact = config.get("admin_contact", None) + ip_range_blacklist = config.get( + "ip_range_blacklist", DEFAULT_IP_RANGE_BLACKLIST + ) + + # Attempt to create an IPSet from the given ranges + try: + self.ip_range_blacklist = IPSet(ip_range_blacklist) + except Exception as e: + raise ConfigError("Invalid range(s) provided in ip_range_blacklist.") from e + # Always blacklist 0.0.0.0, :: + self.ip_range_blacklist.update(["0.0.0.0", "::"]) + + try: + self.ip_range_whitelist = IPSet(config.get("ip_range_whitelist", ())) + except Exception as e: + raise ConfigError("Invalid range(s) provided in ip_range_whitelist.") from e + + # The federation_ip_range_blacklist is used for backwards-compatibility + # and only applies to federation and identity servers. If it is not given, + # default to ip_range_blacklist. + federation_ip_range_blacklist = config.get( + "federation_ip_range_blacklist", ip_range_blacklist + ) + try: + self.federation_ip_range_blacklist = IPSet(federation_ip_range_blacklist) + except Exception as e: + raise ConfigError( + "Invalid range(s) provided in federation_ip_range_blacklist." + ) from e + # Always blacklist 0.0.0.0, :: + self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) + if self.public_baseurl is not None: if self.public_baseurl[-1] != "/": self.public_baseurl += "/" @@ -561,6 +622,10 @@ class ServerConfig(Config): def generate_config_section( self, server_name, data_dir_path, open_private_ports, listeners, **kwargs ): + ip_range_blacklist = "\n".join( + " # - '%s'" % ip for ip in DEFAULT_IP_RANGE_BLACKLIST + ) + _, bind_port = parse_and_validate_server_name(server_name) if bind_port is not None: unsecure_port = bind_port - 400 @@ -752,6 +817,33 @@ class ServerConfig(Config): # #enable_search: false + # Prevent outgoing requests from being sent to the following blacklisted IP address + # CIDR ranges. If this option is not specified then it defaults to private IP + # address ranges (see the example below). + # + # The blacklist applies to the outbound requests for federation, identity servers, + # push servers, and for checking key validity for third-party invite events. + # + # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly + # listed here, since they correspond to unroutable addresses.) + # + # This option replaces federation_ip_range_blacklist in Synapse v1.25.0. 
+ # + #ip_range_blacklist: +%(ip_range_blacklist)s + + # List of IP address CIDR ranges that should be allowed for federation, + # identity servers, push servers, and for checking key validity for + # third-party invite events. This is useful for specifying exceptions to + # wide-ranging blacklisted target IP ranges - e.g. for communication with + # a push server only visible in your network. + # + # This whitelist overrides ip_range_blacklist and defaults to an empty + # list. + # + #ip_range_whitelist: + # - '192.168.1.1' + # List of ports that Synapse should listen on, their purpose and their # configuration. # diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py
index 3d067d29db..3d05abc158 100644
--- a/synapse/config/spam_checker.py
+++ b/synapse/config/spam_checker.py
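The change below passes a config path such as ("spam_checker", "<item 0>") through load_module and into ConfigError, so a bad entry is reported together with its position in the list. A simplified sketch of the pattern — the ConfigError class here is a stand-in, not Synapse's:

class ConfigError(Exception):
    """Stand-in that renders the option path alongside the message."""
    def __init__(self, msg, path=()):
        super().__init__("%s (at %s)" % (msg, ".".join(path)) if path else msg)

def check_spam_checkers(spam_checkers):
    # Mirrors the shape of the loop in the hunk below: every list entry must
    # be a mapping, and errors point at the offending item.
    for i, entry in enumerate(spam_checkers):
        config_path = ("spam_checker", "<item %i>" % i)
        if not isinstance(entry, dict):
            raise ConfigError("expected a mapping", config_path)

check_spam_checkers([{"module": "x", "config": {}}, "oops"])
# ConfigError: expected a mapping (at spam_checker.<item 1>)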
@@ -33,13 +33,14 @@ class SpamCheckerConfig(Config): # spam checker, and thus was simply a dictionary with module # and config keys. Support this old behaviour by checking # to see if the option resolves to a dictionary - self.spam_checkers.append(load_module(spam_checkers)) + self.spam_checkers.append(load_module(spam_checkers, ("spam_checker",))) elif isinstance(spam_checkers, list): - for spam_checker in spam_checkers: + for i, spam_checker in enumerate(spam_checkers): + config_path = ("spam_checker", "<item %i>" % i) if not isinstance(spam_checker, dict): - raise ConfigError("spam_checker syntax is incorrect") + raise ConfigError("expected a mapping", config_path) - self.spam_checkers.append(load_module(spam_checker)) + self.spam_checkers.append(load_module(spam_checker, config_path)) else: raise ConfigError("spam_checker syntax is incorrect") diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 4427676167..93bbd40937 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -93,11 +93,8 @@ class SSOConfig(Config): # - https://my.custom.client/ # Directory in which Synapse will try to find the template files below. - # If not set, default templates from within the Synapse package will be used. - # - # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates. - # If you *do* uncomment it, you will need to make sure that all the templates - # below are in the directory. + # If not set, or the files named below are not found within the template + # directory, default templates from within the Synapse package will be used. # # Synapse will look for the following templates in this directory: # diff --git a/synapse/config/third_party_event_rules.py b/synapse/config/third_party_event_rules.py
index 10a99c792e..c04e1c4e07 100644
--- a/synapse/config/third_party_event_rules.py
+++ b/synapse/config/third_party_event_rules.py
@@ -26,7 +26,9 @@ class ThirdPartyRulesConfig(Config): provider = config.get("third_party_event_rules", None) if provider is not None: - self.third_party_event_rules = load_module(provider) + self.third_party_event_rules = load_module( + provider, ("third_party_event_rules",) + ) def generate_config_section(self, **kwargs): return """\ diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 9ddb8b546b..ad37b93c02 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
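The tls.py hunks below are mostly type hints (Optional[crypto.X509] / crypto.PKey attributes and typed return values). The surrounding certificate check parses the expiry timestamp to compute days remaining; a standalone sketch of that kind of check with pyOpenSSL — not the actual TlsConfig method:

from datetime import datetime
from OpenSSL import crypto

def days_until_expiry(pem_path: str) -> int:
    """Load a PEM certificate and return the whole days until it expires."""
    with open(pem_path, "rb") as f:
        cert = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
    # get_notAfter() is an ASN.1 GENERALIZEDTIME such as b"20250101000000Z".
    expires_on = datetime.strptime(cert.get_notAfter().decode("ascii"), "%Y%m%d%H%M%SZ")
    return (expires_on - datetime.utcnow()).days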
@@ -18,7 +18,7 @@ import os import warnings from datetime import datetime from hashlib import sha256 -from typing import List +from typing import List, Optional from unpaddedbase64 import encode_base64 @@ -177,8 +177,8 @@ class TlsConfig(Config): "use_insecure_ssl_client_just_for_testing_do_not_use" ) - self.tls_certificate = None - self.tls_private_key = None + self.tls_certificate = None # type: Optional[crypto.X509] + self.tls_private_key = None # type: Optional[crypto.PKey] def is_disk_cert_valid(self, allow_self_signed=True): """ @@ -226,12 +226,12 @@ class TlsConfig(Config): days_remaining = (expires_on - now).days return days_remaining - def read_certificate_from_disk(self, require_cert_and_key): + def read_certificate_from_disk(self, require_cert_and_key: bool): """ Read the certificates and private key from disk. Args: - require_cert_and_key (bool): set to True to throw an error if the certificate + require_cert_and_key: set to True to throw an error if the certificate and key file are not given """ if require_cert_and_key: @@ -479,13 +479,13 @@ class TlsConfig(Config): } ) - def read_tls_certificate(self): + def read_tls_certificate(self) -> crypto.X509: """Reads the TLS certificate from the configured file, and returns it Also checks if it is self-signed, and warns if so Returns: - OpenSSL.crypto.X509: the certificate + The certificate """ cert_path = self.tls_certificate_file logger.info("Loading TLS certificate from %s", cert_path) @@ -504,11 +504,11 @@ class TlsConfig(Config): return cert - def read_tls_private_key(self): + def read_tls_private_key(self) -> crypto.PKey: """Reads the TLS private key from the configured file, and returns it Returns: - OpenSSL.crypto.PKey: the private key + The private key """ private_key_path = self.tls_private_key_file logger.info("Loading TLS key from %s", private_key_path) diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py
index 8be1346113..0c1a854f09 100644
--- a/synapse/config/tracer.py
+++ b/synapse/config/tracer.py
@@ -67,7 +67,7 @@ class TracerConfig(Config): # This is a list of regexes which are matched against the server_name of the # homeserver. # - # By defult, it is empty, so no servers are matched. + # By default, it is empty, so no servers are matched. # #homeserver_whitelist: # - ".*" diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index f23e42cdf9..7ca9efec52 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
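The workers.py hunks below add a worker_replication_secret option and a run_background_tasks_on option; a process runs the background tasks when it is the named instance, or when nothing is configured and it is the main ("master") process. The decision logic from the hunk, restated as a free function with plain arguments for illustration:

def should_run_background_tasks(worker_name, run_background_tasks_on):
    """worker_name is None on the main process; run_background_tasks_on is
    the (possibly absent) config value."""
    background_tasks_instance = run_background_tasks_on or "master"
    return (
        worker_name is None and background_tasks_instance == "master"
    ) or worker_name == background_tasks_instance

assert should_run_background_tasks(None, None)            # main process, default
assert not should_run_background_tasks("worker1", None)   # ordinary worker
assert should_run_background_tasks("worker1", "worker1")  # designated worker
assert not should_run_background_tasks(None, "worker1")   # main process opts out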
@@ -85,6 +85,9 @@ class WorkerConfig(Config): # The port on the main synapse for HTTP replication endpoint self.worker_replication_http_port = config.get("worker_replication_http_port") + # The shared secret used for authentication when connecting to the main synapse. + self.worker_replication_secret = config.get("worker_replication_secret", None) + self.worker_name = config.get("worker_name", self.worker_app) self.worker_main_http_uri = config.get("worker_main_http_uri", None) @@ -132,6 +135,19 @@ class WorkerConfig(Config): self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events) + # Whether this worker should run background tasks or not. + # + # As a note for developers, the background tasks guarded by this should + # be able to run on only a single instance (meaning that they don't + # depend on any in-memory state of a particular worker). + # + # No effort is made to ensure only a single instance of these tasks is + # running. + background_tasks_instance = config.get("run_background_tasks_on") or "master" + self.run_background_tasks = ( + self.worker_name is None and background_tasks_instance == "master" + ) or self.worker_name == background_tasks_instance + def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ ## Workers ## @@ -167,6 +183,18 @@ class WorkerConfig(Config): #stream_writers: # events: worker1 # typing: worker1 + + # The worker that is used to run background tasks (e.g. cleaning up expired + # data). If not provided this defaults to the main process. + # + #run_background_tasks_on: worker1 + + # A shared secret used by the replication APIs to authenticate HTTP requests + # from workers. + # + # By default this is unused and traffic is not authenticated. + # + #worker_replication_secret: "" """ def read_arguments(self, args): diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index 79668a402e..74b67b230a 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -149,7 +149,7 @@ class FederationPolicyForHTTPS: return SSLClientConnectionCreator(host, ssl_context, should_verify) def creatorForNetloc(self, hostname, port): - """Implements the IPolicyForHTTPS interace so that this can be passed + """Implements the IPolicyForHTTPS interface so that this can be passed directly to agents. """ return self.get_options(hostname) @@ -227,7 +227,7 @@ class ConnectionVerifier: # This code is based on twisted.internet.ssl.ClientTLSOptions. - def __init__(self, hostname: bytes, verify_certs): + def __init__(self, hostname: bytes, verify_certs: bool): self._verify_certs = verify_certs _decoded = hostname.decode("ascii") diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
index 0422c43fab..8fb116ae18 100644
--- a/synapse/crypto/event_signing.py
+++ b/synapse/crypto/event_signing.py
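event_signing.py below gains type hints (a Hasher alias and typed return tuples); compute_content_hash strips the transient fields, canonicalises the JSON and hashes the result. A condensed, runnable sketch of that flow, using a toy event dict (the set of stripped fields shown here is abridged):

import hashlib
from canonicaljson import encode_canonical_json

def compute_content_hash(event_dict, hash_algorithm=hashlib.sha256):
    """Hash the unredacted event, minus signatures and other transient keys."""
    event_dict = dict(event_dict)
    for key in ("age_ts", "unsigned", "signatures", "hashes", "outlier", "destinations"):
        event_dict.pop(key, None)
    hashed = hash_algorithm(encode_canonical_json(event_dict))
    return hashed.name, hashed.digest()

name, digest = compute_content_hash(
    {"type": "m.room.message", "content": {"body": "hi"}, "unsigned": {"age": 1}}
)
print(name, len(digest))  # sha256 32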
@@ -18,7 +18,7 @@ import collections.abc import hashlib import logging -from typing import Dict +from typing import Any, Callable, Dict, Tuple from canonicaljson import encode_canonical_json from signedjson.sign import sign_json @@ -27,13 +27,18 @@ from unpaddedbase64 import decode_base64, encode_base64 from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion +from synapse.events import EventBase from synapse.events.utils import prune_event, prune_event_dict from synapse.types import JsonDict logger = logging.getLogger(__name__) +Hasher = Callable[[bytes], "hashlib._Hash"] -def check_event_content_hash(event, hash_algorithm=hashlib.sha256): + +def check_event_content_hash( + event: EventBase, hash_algorithm: Hasher = hashlib.sha256 +) -> bool: """Check whether the hash for this PDU matches the contents""" name, expected_hash = compute_content_hash(event.get_pdu_json(), hash_algorithm) logger.debug( @@ -67,18 +72,19 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256): return message_hash_bytes == expected_hash -def compute_content_hash(event_dict, hash_algorithm): +def compute_content_hash( + event_dict: Dict[str, Any], hash_algorithm: Hasher +) -> Tuple[str, bytes]: """Compute the content hash of an event, which is the hash of the unredacted event. Args: - event_dict (dict): The unredacted event as a dict + event_dict: The unredacted event as a dict hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use to hash the event Returns: - tuple[str, bytes]: A tuple of the name of hash and the hash as raw - bytes. + A tuple of the name of hash and the hash as raw bytes. """ event_dict = dict(event_dict) event_dict.pop("age_ts", None) @@ -94,18 +100,19 @@ def compute_content_hash(event_dict, hash_algorithm): return hashed.name, hashed.digest() -def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256): +def compute_event_reference_hash( + event, hash_algorithm: Hasher = hashlib.sha256 +) -> Tuple[str, bytes]: """Computes the event reference hash. This is the hash of the redacted event. Args: - event (FrozenEvent) + event hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use to hash the event Returns: - tuple[str, bytes]: A tuple of the name of hash and the hash as raw - bytes. + A tuple of the name of hash and the hash as raw bytes. """ tmp_event = prune_event(event) event_dict = tmp_event.get_pdu_json() @@ -156,7 +163,7 @@ def add_hashes_and_signatures( event_dict: JsonDict, signature_name: str, signing_key: SigningKey, -): +) -> None: """Add content hash and sign the event Args: diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index c04ad77cf9..902128a23c 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
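The keyring.py hunks below are mostly type annotations; the recurring data shape is keys_to_fetch, a mapping of server_name -> key_id -> minimum valid-until timestamp, and KeyFetcher becomes an abstract base class. A toy fetcher showing that interface — the in-memory store and plain-string keys are stand-ins for the real data store and FetchKeyResult objects:

import abc
import asyncio
from typing import Dict

class KeyFetcher(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    async def get_keys(
        self, keys_to_fetch: Dict[str, Dict[str, int]]
    ) -> Dict[str, Dict[str, str]]:
        """keys_to_fetch: server_name -> key_id -> min_valid_ts."""
        raise NotImplementedError

class InMemoryKeyFetcher(KeyFetcher):
    """Hypothetical fetcher backed by a plain dict instead of the data store."""
    def __init__(self, known_keys: Dict[str, Dict[str, str]]):
        self._known = known_keys

    async def get_keys(self, keys_to_fetch):
        results = {}  # type: Dict[str, Dict[str, str]]
        for server_name, key_ids in keys_to_fetch.items():
            for key_id in key_ids:
                key = self._known.get(server_name, {}).get(key_id)
                if key is not None:
                    results.setdefault(server_name, {})[key_id] = key
        return results

fetcher = InMemoryKeyFetcher({"example.org": {"ed25519:a": "base64-verify-key"}})
print(asyncio.run(fetcher.get_keys({"example.org": {"ed25519:a": 0}})))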
@@ -14,9 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc import logging import urllib from collections import defaultdict +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple import attr from signedjson.key import ( @@ -40,6 +42,7 @@ from synapse.api.errors import ( RequestSendFailed, SynapseError, ) +from synapse.config.key import TrustedKeyServer from synapse.logging.context import ( PreserveLoggingContext, make_deferred_yieldable, @@ -47,11 +50,15 @@ from synapse.logging.context import ( run_in_background, ) from synapse.storage.keys import FetchKeyResult +from synapse.types import JsonDict from synapse.util import unwrapFirstError from synapse.util.async_helpers import yieldable_gather_results from synapse.util.metrics import Measure from synapse.util.retryutils import NotRetryingDestination +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -61,16 +68,17 @@ class VerifyJsonRequest: A request to verify a JSON object. Attributes: - server_name(str): The name of the server to verify against. - - key_ids(set[str]): The set of key_ids to that could be used to verify the - JSON object + server_name: The name of the server to verify against. - json_object(dict): The JSON object to verify. + json_object: The JSON object to verify. - minimum_valid_until_ts (int): time at which we require the signing key to + minimum_valid_until_ts: time at which we require the signing key to be valid. (0 implies we don't care) + request_name: The name of the request. + + key_ids: The set of key_ids to that could be used to verify the JSON object + key_ready (Deferred[str, str, nacl.signing.VerifyKey]): A deferred (server_name, key_id, verify_key) tuple that resolves when a verify key has been fetched. The deferreds' callbacks are run with no @@ -80,12 +88,12 @@ class VerifyJsonRequest: errbacks with an M_UNAUTHORIZED SynapseError. """ - server_name = attr.ib() - json_object = attr.ib() - minimum_valid_until_ts = attr.ib() - request_name = attr.ib() - key_ids = attr.ib(init=False) - key_ready = attr.ib(default=attr.Factory(defer.Deferred)) + server_name = attr.ib(type=str) + json_object = attr.ib(type=JsonDict) + minimum_valid_until_ts = attr.ib(type=int) + request_name = attr.ib(type=str) + key_ids = attr.ib(init=False, type=List[str]) + key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred) def __attrs_post_init__(self): self.key_ids = signature_ids(self.json_object, self.server_name) @@ -96,7 +104,9 @@ class KeyLookupError(ValueError): class Keyring: - def __init__(self, hs, key_fetchers=None): + def __init__( + self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None + ): self.clock = hs.get_clock() if key_fetchers is None: @@ -112,22 +122,26 @@ class Keyring: # completes. # # These are regular, logcontext-agnostic Deferreds. 
- self.key_downloads = {} + self.key_downloads = {} # type: Dict[str, defer.Deferred] def verify_json_for_server( - self, server_name, json_object, validity_time, request_name - ): + self, + server_name: str, + json_object: JsonDict, + validity_time: int, + request_name: str, + ) -> defer.Deferred: """Verify that a JSON object has been signed by a given server Args: - server_name (str): name of the server which must have signed this object + server_name: name of the server which must have signed this object - json_object (dict): object to be checked + json_object: object to be checked - validity_time (int): timestamp at which we require the signing key to + validity_time: timestamp at which we require the signing key to be valid. (0 implies we don't care) - request_name (str): an identifier for this json object (eg, an event id) + request_name: an identifier for this json object (eg, an event id) for logging. Returns: @@ -138,12 +152,14 @@ class Keyring: requests = (req,) return make_deferred_yieldable(self._verify_objects(requests)[0]) - def verify_json_objects_for_server(self, server_and_json): + def verify_json_objects_for_server( + self, server_and_json: Iterable[Tuple[str, dict, int, str]] + ) -> List[defer.Deferred]: """Bulk verifies signatures of json objects, bulk fetching keys as necessary. Args: - server_and_json (iterable[Tuple[str, dict, int, str]): + server_and_json: Iterable of (server_name, json_object, validity_time, request_name) tuples. @@ -164,13 +180,14 @@ class Keyring: for server_name, json_object, validity_time, request_name in server_and_json ) - def _verify_objects(self, verify_requests): + def _verify_objects( + self, verify_requests: Iterable[VerifyJsonRequest] + ) -> List[defer.Deferred]: """Does the work of verify_json_[objects_]for_server Args: - verify_requests (iterable[VerifyJsonRequest]): - Iterable of verification requests. + verify_requests: Iterable of verification requests. Returns: List<Deferred[None]>: for each input item, a deferred indicating success @@ -182,7 +199,7 @@ class Keyring: key_lookups = [] handle = preserve_fn(_handle_key_deferred) - def process(verify_request): + def process(verify_request: VerifyJsonRequest) -> defer.Deferred: """Process an entry in the request list Adds a key request to key_lookups, and returns a deferred which @@ -222,18 +239,20 @@ class Keyring: return results - async def _start_key_lookups(self, verify_requests): + async def _start_key_lookups( + self, verify_requests: List[VerifyJsonRequest] + ) -> None: """Sets off the key fetches for each verify request Once each fetch completes, verify_request.key_ready will be resolved. Args: - verify_requests (List[VerifyJsonRequest]): + verify_requests: """ try: # map from server name to a set of outstanding request ids - server_to_request_ids = {} + server_to_request_ids = {} # type: Dict[str, Set[int]] for verify_request in verify_requests: server_name = verify_request.server_name @@ -275,11 +294,11 @@ class Keyring: except Exception: logger.exception("Error starting key lookups") - async def wait_for_previous_lookups(self, server_names) -> None: + async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None: """Waits for any previous key lookups for the given servers to finish. 
Args: - server_names (Iterable[str]): list of servers which we want to look up + server_names: list of servers which we want to look up Returns: Resolves once all key lookups for the given servers have @@ -304,7 +323,7 @@ class Keyring: loop_count += 1 - def _get_server_verify_keys(self, verify_requests): + def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None: """Tries to find at least one key for each verify request For each verify_request, verify_request.key_ready is called back with @@ -312,7 +331,7 @@ class Keyring: with a SynapseError if none of the keys are found. Args: - verify_requests (list[VerifyJsonRequest]): list of verify requests + verify_requests: list of verify requests """ remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called} @@ -366,17 +385,19 @@ class Keyring: run_in_background(do_iterations) - async def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests): + async def _attempt_key_fetches_with_fetcher( + self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest] + ): """Use a key fetcher to attempt to satisfy some key requests Args: - fetcher (KeyFetcher): fetcher to use to fetch the keys - remaining_requests (set[VerifyJsonRequest]): outstanding key requests. + fetcher: fetcher to use to fetch the keys + remaining_requests: outstanding key requests. Any successfully-completed requests will be removed from the list. """ - # dict[str, dict[str, int]]: keys to fetch. + # The keys to fetch. # server_name -> key_id -> min_valid_ts - missing_keys = defaultdict(dict) + missing_keys = defaultdict(dict) # type: Dict[str, Dict[str, int]] for verify_request in remaining_requests: # any completed requests should already have been removed @@ -438,16 +459,18 @@ class Keyring: remaining_requests.difference_update(completed) -class KeyFetcher: - async def get_keys(self, keys_to_fetch): +class KeyFetcher(metaclass=abc.ABCMeta): + @abc.abstractmethod + async def get_keys( + self, keys_to_fetch: Dict[str, Dict[str, int]] + ) -> Dict[str, Dict[str, FetchKeyResult]]: """ Args: - keys_to_fetch (dict[str, dict[str, int]]): + keys_to_fetch: the keys to be fetched. 
server_name -> key_id -> min_valid_ts Returns: - Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]: - map from server_name -> key_id -> FetchKeyResult + Map from server_name -> key_id -> FetchKeyResult """ raise NotImplementedError @@ -455,31 +478,35 @@ class KeyFetcher: class StoreKeyFetcher(KeyFetcher): """KeyFetcher impl which fetches keys from our data store""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() - async def get_keys(self, keys_to_fetch): + async def get_keys( + self, keys_to_fetch: Dict[str, Dict[str, int]] + ) -> Dict[str, Dict[str, FetchKeyResult]]: """see KeyFetcher.get_keys""" - keys_to_fetch = ( + key_ids_to_fetch = ( (server_name, key_id) for server_name, keys_for_server in keys_to_fetch.items() for key_id in keys_for_server.keys() ) - res = await self.store.get_server_verify_keys(keys_to_fetch) - keys = {} + res = await self.store.get_server_verify_keys(key_ids_to_fetch) + keys = {} # type: Dict[str, Dict[str, FetchKeyResult]] for (server_name, key_id), key in res.items(): keys.setdefault(server_name, {})[key_id] = key return keys -class BaseV2KeyFetcher: - def __init__(self, hs): +class BaseV2KeyFetcher(KeyFetcher): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.config = hs.get_config() - async def process_v2_response(self, from_server, response_json, time_added_ms): + async def process_v2_response( + self, from_server: str, response_json: JsonDict, time_added_ms: int + ) -> Dict[str, FetchKeyResult]: """Parse a 'Server Keys' structure from the result of a /key request This is used to parse either the entirety of the response from @@ -493,16 +520,16 @@ class BaseV2KeyFetcher: to /_matrix/key/v2/query. Args: - from_server (str): the name of the server producing this result: either + from_server: the name of the server producing this result: either the origin server for a /_matrix/key/v2/server request, or the notary for a /_matrix/key/v2/query. 
- response_json (dict): the json-decoded Server Keys response object + response_json: the json-decoded Server Keys response object - time_added_ms (int): the timestamp to record in server_keys_json + time_added_ms: the timestamp to record in server_keys_json Returns: - Deferred[dict[str, FetchKeyResult]]: map from key_id to result object + Map from key_id to result object """ ts_valid_until_ms = response_json["valid_until_ts"] @@ -575,21 +602,22 @@ class BaseV2KeyFetcher: class PerspectivesKeyFetcher(BaseV2KeyFetcher): """KeyFetcher impl which fetches keys from the "perspectives" servers""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.clock = hs.get_clock() - self.client = hs.get_http_client() + self.client = hs.get_federation_http_client() self.key_servers = self.config.key_servers - async def get_keys(self, keys_to_fetch): + async def get_keys( + self, keys_to_fetch: Dict[str, Dict[str, int]] + ) -> Dict[str, Dict[str, FetchKeyResult]]: """see KeyFetcher.get_keys""" - async def get_key(key_server): + async def get_key(key_server: TrustedKeyServer) -> Dict: try: - result = await self.get_server_verify_key_v2_indirect( + return await self.get_server_verify_key_v2_indirect( keys_to_fetch, key_server ) - return result except KeyLookupError as e: logger.warning( "Key lookup failed from %r: %s", key_server.server_name, e @@ -611,25 +639,25 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): ).addErrback(unwrapFirstError) ) - union_of_keys = {} + union_of_keys = {} # type: Dict[str, Dict[str, FetchKeyResult]] for result in results: for server_name, keys in result.items(): union_of_keys.setdefault(server_name, {}).update(keys) return union_of_keys - async def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server): + async def get_server_verify_key_v2_indirect( + self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer + ) -> Dict[str, Dict[str, FetchKeyResult]]: """ Args: - keys_to_fetch (dict[str, dict[str, int]]): + keys_to_fetch: the keys to be fetched. 
server_name -> key_id -> min_valid_ts - key_server (synapse.config.key.TrustedKeyServer): notary server to query for - the keys + key_server: notary server to query for the keys Returns: - dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]: map - from server_name -> key_id -> FetchKeyResult + Map from server_name -> key_id -> FetchKeyResult Raises: KeyLookupError if there was an error processing the entire response from @@ -662,11 +690,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): except HttpResponseException as e: raise KeyLookupError("Remote server returned an error: %s" % (e,)) - keys = {} - added_keys = [] + keys = {} # type: Dict[str, Dict[str, FetchKeyResult]] + added_keys = [] # type: List[Tuple[str, str, FetchKeyResult]] time_now_ms = self.clock.time_msec() + assert isinstance(query_response, dict) for response in query_response["server_keys"]: # do this first, so that we can give useful errors thereafter server_name = response.get("server_name") @@ -704,14 +733,15 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): return keys - def _validate_perspectives_response(self, key_server, response): + def _validate_perspectives_response( + self, key_server: TrustedKeyServer, response: JsonDict + ) -> None: """Optionally check the signature on the result of a /key/query request Args: - key_server (synapse.config.key.TrustedKeyServer): the notary server that - produced this result + key_server: the notary server that produced this result - response (dict): the json-decoded Server Keys response object + response: the json-decoded Server Keys response object """ perspective_name = key_server.server_name perspective_keys = key_server.verify_keys @@ -745,25 +775,26 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): class ServerKeyFetcher(BaseV2KeyFetcher): """KeyFetcher impl which fetches keys from the origin servers""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.clock = hs.get_clock() - self.client = hs.get_http_client() + self.client = hs.get_federation_http_client() - async def get_keys(self, keys_to_fetch): + async def get_keys( + self, keys_to_fetch: Dict[str, Dict[str, int]] + ) -> Dict[str, Dict[str, FetchKeyResult]]: """ Args: - keys_to_fetch (dict[str, iterable[str]]): + keys_to_fetch: the keys to be fetched. 
server_name -> key_ids Returns: - dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]: - map from server_name -> key_id -> FetchKeyResult + Map from server_name -> key_id -> FetchKeyResult """ results = {} - async def get_key(key_to_fetch_item): + async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None: server_name, key_ids = key_to_fetch_item try: keys = await self.get_server_verify_key_v2_direct(server_name, key_ids) @@ -778,20 +809,22 @@ class ServerKeyFetcher(BaseV2KeyFetcher): await yieldable_gather_results(get_key, keys_to_fetch.items()) return results - async def get_server_verify_key_v2_direct(self, server_name, key_ids): + async def get_server_verify_key_v2_direct( + self, server_name: str, key_ids: Iterable[str] + ) -> Dict[str, FetchKeyResult]: """ Args: - server_name (str): - key_ids (iterable[str]): + server_name: + key_ids: Returns: - dict[str, FetchKeyResult]: map from key ID to lookup result + Map from key ID to lookup result Raises: KeyLookupError if there was a problem making the lookup """ - keys = {} # type: dict[str, FetchKeyResult] + keys = {} # type: Dict[str, FetchKeyResult] for requested_key_id in key_ids: # we may have found this key as a side-effect of asking for another. @@ -825,6 +858,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher): except HttpResponseException as e: raise KeyLookupError("Remote server returned an error: %s" % (e,)) + assert isinstance(response, dict) if response["server_name"] != server_name: raise KeyLookupError( "Expected a response for server %r not %r" @@ -846,11 +880,11 @@ class ServerKeyFetcher(BaseV2KeyFetcher): return keys -async def _handle_key_deferred(verify_request) -> None: +async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None: """Waits for the key to become available, and then performs a verification Args: - verify_request (VerifyJsonRequest): + verify_request: Raises: SynapseError if there was a problem performing the verification diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 8c907ad596..56f8dc9caf 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -446,6 +446,8 @@ def check_redaction( if room_version_obj.event_format == EventFormatVersions.V1: redacter_domain = get_domain_from_id(event.event_id) + if not isinstance(event.redacts, str): + return False redactee_domain = get_domain_from_id(event.redacts) if redacter_domain == redactee_domain: return True diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index dc49df0812..8028663fa8 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
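events/__init__.py below moves stream_ordering into a real attribute and adds EventBase.freeze(), which swaps the underlying dict for a frozen copy so later code (e.g. third-party rules modules) cannot mutate the event in place. Synapse's freeze helper lives in synapse.util.frozenutils; a simplified stand-in showing the idea:

from types import MappingProxyType

def freeze(value):
    """Recursively wrap dicts in read-only views (simplified stand-in for
    synapse.util.frozenutils.freeze, which uses frozendict)."""
    if isinstance(value, dict):
        return MappingProxyType({k: freeze(v) for k, v in value.items()})
    if isinstance(value, (list, tuple)):
        return tuple(freeze(v) for v in value)
    return value

event_dict = freeze({"type": "m.room.message", "content": {"body": "hi"}})
try:
    event_dict["content"]["body"] = "changed"
except TypeError as e:
    print("mutation rejected:", e)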
@@ -59,7 +59,7 @@ class DictProperty: # # To exclude the KeyError from the traceback, we explicitly # 'raise from e1.__context__' (which is better than 'raise from None', - # becuase that would omit any *earlier* exceptions). + # because that would omit any *earlier* exceptions). # raise AttributeError( "'%s' has no '%s' property" % (type(instance), self.key) @@ -97,13 +97,16 @@ class DefaultDictProperty(DictProperty): class _EventInternalMetadata: - __slots__ = ["_dict"] + __slots__ = ["_dict", "stream_ordering"] def __init__(self, internal_metadata_dict: JsonDict): # we have to copy the dict, because it turns out that the same dict is # reused. TODO: fix that self._dict = dict(internal_metadata_dict) + # the stream ordering of this event. None, until it has been persisted. + self.stream_ordering = None # type: Optional[int] + outlier = DictProperty("outlier") # type: bool out_of_band_membership = DictProperty("out_of_band_membership") # type: bool send_on_behalf_of = DictProperty("send_on_behalf_of") # type: str @@ -113,7 +116,6 @@ class _EventInternalMetadata: redacted = DictProperty("redacted") # type: bool txn_id = DictProperty("txn_id") # type: str token_id = DictProperty("token_id") # type: str - stream_ordering = DictProperty("stream_ordering") # type: int # XXX: These are set by StreamWorkerStore._set_before_and_after. # I'm pretty sure that these are never persisted to the database, so shouldn't @@ -310,6 +312,12 @@ class EventBase(metaclass=abc.ABCMeta): """ return [e for e, _ in self.auth_events] + def freeze(self): + """'Freeze' the event dict, so it cannot be modified by accident""" + + # this will be a no-op if the event dict is already frozen. + self._dict = freeze(self._dict) + class FrozenEvent(EventBase): format_version = EventFormatVersions.V1 # All events of this type are V1 @@ -360,7 +368,7 @@ class FrozenEvent(EventBase): return self.__repr__() def __repr__(self): - return "<FrozenEvent event_id='%s', type='%s', state_key='%s'>" % ( + return "<FrozenEvent event_id=%r, type=%r, state_key=%r>" % ( self.get("event_id", None), self.get("type", None), self.get("state_key", None), @@ -443,7 +451,7 @@ class FrozenEventV2(EventBase): return self.__repr__() def __repr__(self): - return "<%s event_id='%s', type='%s', state_key='%s'>" % ( + return "<%s event_id=%r, type=%r, state_key=%r>" % ( self.__class__.__name__, self.event_id, self.get("type", None), diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index b6c47be646..07df258e6e 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -97,32 +97,37 @@ class EventBuilder: def is_state(self): return self._state_key is not None - async def build(self, prev_event_ids: List[str]) -> EventBase: + async def build( + self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]], + ) -> EventBase: """Transform into a fully signed and hashed event Args: prev_event_ids: The event IDs to use as the prev events + auth_event_ids: The event IDs to use as the auth events. + Should normally be set to None, which will cause them to be calculated + based on the room state at the prev_events. Returns: The signed and hashed event. """ - - state_ids = await self._state.get_current_state_ids( - self.room_id, prev_event_ids - ) - auth_ids = self._auth.compute_auth_events(self, state_ids) + if auth_event_ids is None: + state_ids = await self._state.get_current_state_ids( + self.room_id, prev_event_ids + ) + auth_event_ids = self._auth.compute_auth_events(self, state_ids) format_version = self.room_version.event_format if format_version == EventFormatVersions.V1: # The types of auth/prev events changes between event versions. auth_events = await self._store.add_event_hashes( - auth_ids + auth_event_ids ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] prev_events = await self._store.add_event_hashes( prev_event_ids ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] else: - auth_events = auth_ids + auth_events = auth_event_ids prev_events = prev_event_ids old_depth = await self._store.get_max_depth_of(prev_event_ids) diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index b0fc859a47..e7e3a7b9a4 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
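spamcheck.py below makes every SpamChecker entry point async and wraps each module callback in maybe_awaitable, so legacy synchronous spam checkers and new async ones both keep working, and check_event_for_spam may now return an error string as well as a boolean. A self-contained sketch of that pattern — the checker classes here are made up, and maybe_awaitable (normally from synapse.util.async_helpers) is re-implemented inline for the example:

import asyncio
import inspect

async def maybe_awaitable(value):
    """Await the value if it is awaitable, otherwise return it as-is."""
    if inspect.isawaitable(value):
        return await value
    return value

class LegacyChecker:
    def check_event_for_spam(self, event):
        return "no shouting please" if event.get("body", "").isupper() else False

class AsyncChecker:
    async def check_event_for_spam(self, event):
        await asyncio.sleep(0)  # e.g. an HTTP call to an external scanner
        return False

async def check_event_for_spam(checkers, event):
    for checker in checkers:
        result = await maybe_awaitable(checker.check_event_for_spam(event))
        if result:
            return result  # True or an error string, as described in the diff
    return False

print(asyncio.run(check_event_for_spam([LegacyChecker(), AsyncChecker()], {"body": "HELLO"})))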
@@ -15,31 +15,34 @@ # limitations under the License. import inspect -from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union -from synapse.spam_checker_api import RegistrationBehaviour, SpamCheckerApi +from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import Collection +from synapse.util.async_helpers import maybe_awaitable -MYPY = False -if MYPY: +if TYPE_CHECKING: + import synapse.events import synapse.server class SpamChecker: def __init__(self, hs: "synapse.server.HomeServer"): self.spam_checkers = [] # type: List[Any] + api = hs.get_module_api() for module, config in hs.config.spam_checkers: # Older spam checkers don't accept the `api` argument, so we # try and detect support. spam_args = inspect.getfullargspec(module) if "api" in spam_args.args: - api = SpamCheckerApi(hs) self.spam_checkers.append(module(config=config, api=api)) else: self.spam_checkers.append(module(config=config)) - def check_event_for_spam(self, event: "synapse.events.EventBase") -> bool: + async def check_event_for_spam( + self, event: "synapse.events.EventBase" + ) -> Union[bool, str]: """Checks if a given event is considered "spammy" by this server. If the server considers an event spammy, then it will be rejected if @@ -50,15 +53,16 @@ class SpamChecker: event: the event to be checked Returns: - True if the event is spammy. + True or a string if the event is spammy. If a string is returned it + will be used as the error message returned to the user. """ for spam_checker in self.spam_checkers: - if spam_checker.check_event_for_spam(event): + if await maybe_awaitable(spam_checker.check_event_for_spam(event)): return True return False - def user_may_invite( + async def user_may_invite( self, inviter_userid: str, invitee_userid: str, room_id: str ) -> bool: """Checks if a given user may send an invite @@ -75,14 +79,18 @@ class SpamChecker: """ for spam_checker in self.spam_checkers: if ( - spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id) + await maybe_awaitable( + spam_checker.user_may_invite( + inviter_userid, invitee_userid, room_id + ) + ) is False ): return False return True - def user_may_create_room(self, userid: str) -> bool: + async def user_may_create_room(self, userid: str) -> bool: """Checks if a given user may create a room If this method returns false, the creation request will be rejected. @@ -94,12 +102,15 @@ class SpamChecker: True if the user may create a room, otherwise False """ for spam_checker in self.spam_checkers: - if spam_checker.user_may_create_room(userid) is False: + if ( + await maybe_awaitable(spam_checker.user_may_create_room(userid)) + is False + ): return False return True - def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool: + async def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool: """Checks if a given user may create a room alias If this method returns false, the association request will be rejected. 
@@ -112,12 +123,17 @@ class SpamChecker: True if the user may create a room alias, otherwise False """ for spam_checker in self.spam_checkers: - if spam_checker.user_may_create_room_alias(userid, room_alias) is False: + if ( + await maybe_awaitable( + spam_checker.user_may_create_room_alias(userid, room_alias) + ) + is False + ): return False return True - def user_may_publish_room(self, userid: str, room_id: str) -> bool: + async def user_may_publish_room(self, userid: str, room_id: str) -> bool: """Checks if a given user may publish a room to the directory If this method returns false, the publish request will be rejected. @@ -130,12 +146,17 @@ class SpamChecker: True if the user may publish the room, otherwise False """ for spam_checker in self.spam_checkers: - if spam_checker.user_may_publish_room(userid, room_id) is False: + if ( + await maybe_awaitable( + spam_checker.user_may_publish_room(userid, room_id) + ) + is False + ): return False return True - def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool: + async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool: """Checks if a user ID or display name are considered "spammy" by this server. If the server considers a username spammy, then it will not be included in @@ -157,12 +178,12 @@ class SpamChecker: if checker: # Make a copy of the user profile object to ensure the spam checker # cannot modify it. - if checker(user_profile.copy()): + if await maybe_awaitable(checker(user_profile.copy())): return True return False - def check_registration_for_spam( + async def check_registration_for_spam( self, email_threepid: Optional[dict], username: Optional[str], @@ -185,7 +206,9 @@ class SpamChecker: # spam checker checker = getattr(spam_checker, "check_registration_for_spam", None) if checker: - behaviour = checker(email_threepid, username, request_info) + behaviour = await maybe_awaitable( + checker(email_threepid, username, request_info) + ) assert isinstance(behaviour, RegistrationBehaviour) if behaviour != RegistrationBehaviour.ALLOW: return behaviour diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 9d5310851c..77fbd3f68a 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
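third_party_rules.py below changes the module contract: check_event_allowed receives a frozen event plus a state map keyed by (event type, state key) and may return True, False, or a replacement event dict; modules are now constructed with module_api rather than an HTTP client. A hypothetical module illustrating that contract (it assumes the usual EventBase accessors type, content and get_dict):

class ExampleThirdPartyRules:
    """Hypothetical module illustrating the return contract described below."""

    def __init__(self, config, module_api):
        self._config = config
        self._api = module_api

    async def check_event_allowed(self, event, state_events):
        # state_events maps (event type, state key) -> the corresponding state
        # event, e.g. state_events.get(("m.room.create", "")) for the create event.
        if event.type == "m.room.message":
            body = event.content.get("body", "")
            if "forbidden" in body:
                # Reject outright: Synapse turns this into M_FORBIDDEN.
                return False
            if body.endswith("!!!"):
                # Return a replacement dict; the event itself is frozen and
                # cannot be edited in place.
                new_event = event.get_dict()
                new_event["content"] = dict(new_event["content"], body=body.rstrip("!"))
                return new_event
        return True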
@@ -13,9 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Callable, Union + from synapse.events import EventBase from synapse.events.snapshot import EventContext -from synapse.types import Requester +from synapse.types import Requester, StateMap class ThirdPartyEventRules: @@ -38,20 +40,25 @@ class ThirdPartyEventRules: if module is not None: self.third_party_rules = module( - config=config, http_client=hs.get_simple_http_client() + config=config, module_api=hs.get_module_api(), ) async def check_event_allowed( self, event: EventBase, context: EventContext - ) -> bool: + ) -> Union[bool, dict]: """Check if a provided event should be allowed in the given context. + The module can return: + * True: the event is allowed. + * False: the event is not allowed, and should be rejected with M_FORBIDDEN. + * a dict: replacement event data. + Args: event: The event to be checked. context: The context of the event. Returns: - True if the event should be allowed, False if not. + The result from the ThirdPartyRules module, as above """ if self.third_party_rules is None: return True @@ -59,12 +66,15 @@ class ThirdPartyEventRules: prev_state_ids = await context.get_prev_state_ids() # Retrieve the state events from the database. - state_events = {} - for key, event_id in prev_state_ids.items(): - state_events[key] = await self.store.get_event(event_id, allow_none=True) + events = await self.store.get_events(prev_state_ids.values()) + state_events = {(ev.type, ev.state_key): ev for ev in events.values()} - ret = await self.third_party_rules.check_event_allowed(event, state_events) - return ret + # Ensure that the event is frozen, to make sure that the module is not tempted + # to try to modify it. Any attempt to modify it at this point will invalidate + # the hashes and signatures. + event.freeze() + + return await self.third_party_rules.check_event_allowed(event, state_events) async def on_create_room( self, requester: Requester, config: dict, is_requester_admin: bool @@ -106,6 +116,48 @@ class ThirdPartyEventRules: if self.third_party_rules is None: return True + state_events = await self._get_state_map_for_room(room_id) + + ret = await self.third_party_rules.check_threepid_can_be_invited( + medium, address, state_events + ) + return ret + + async def check_visibility_can_be_modified( + self, room_id: str, new_visibility: str + ) -> bool: + """Check if a room is allowed to be published to, or removed from, the public room + list. + + Args: + room_id: The ID of the room. + new_visibility: The new visibility state. Either "public" or "private". + + Returns: + True if the room's visibility can be modified, False if not. + """ + if self.third_party_rules is None: + return True + + check_func = getattr( + self.third_party_rules, "check_visibility_can_be_modified", None + ) + if not check_func or not isinstance(check_func, Callable): + return True + + state_events = await self._get_state_map_for_room(room_id) + + return await check_func(room_id, state_events, new_visibility) + + async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]: + """Given a room ID, return the state events of that room. + + Args: + room_id: The ID of the room. + + Returns: + A dict mapping (event type, state key) to state event. 
+ """ state_ids = await self.store.get_filtered_current_state_ids(room_id) room_state_events = await self.store.get_events(state_ids.values()) @@ -113,7 +165,4 @@ class ThirdPartyEventRules: for key, event_id in state_ids.items(): state_events[key] = room_state_events[event_id] - ret = await self.third_party_rules.check_threepid_can_be_invited( - medium, address, state_events - ) - return ret + return state_events diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 32c73d3413..14f7f1156f 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -49,6 +49,11 @@ def prune_event(event: EventBase) -> EventBase: pruned_event_dict, event.room_version, event.internal_metadata.get_dict() ) + # copy the internal fields + pruned_event.internal_metadata.stream_ordering = ( + event.internal_metadata.stream_ordering + ) + # Mark the event as redacted pruned_event.internal_metadata.redacted = True @@ -175,7 +180,7 @@ def only_fields(dictionary, fields): in 'fields'. If there are no event fields specified then all fields are included. - The entries may include '.' charaters to indicate sub-fields. + The entries may include '.' characters to indicate sub-fields. So ['content.body'] will include the 'body' field of the 'content' object. A literal '.' character in a field name may be escaped using a '\'. diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 9df35b54ba..f8f3b1a31e 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -13,20 +13,26 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Union + from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes, Membership from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import EventFormatVersions +from synapse.config.homeserver import HomeServerConfig +from synapse.events import EventBase +from synapse.events.builder import EventBuilder from synapse.events.utils import validate_canonicaljson +from synapse.federation.federation_server import server_matches_acl_event from synapse.types import EventID, RoomID, UserID class EventValidator: - def validate_new(self, event, config): + def validate_new(self, event: EventBase, config: HomeServerConfig): """Validates the event has roughly the right format Args: - event (FrozenEvent): The event to validate. - config (Config): The homeserver's configuration. + event: The event to validate. + config: The homeserver's configuration. """ self.validate_builder(event) @@ -76,13 +82,22 @@ class EventValidator: if event.type == EventTypes.Retention: self._validate_retention(event) - def _validate_retention(self, event): + if event.type == EventTypes.ServerACL: + if not server_matches_acl_event(config.server_name, event): + raise SynapseError( + 400, "Can't create an ACL event that denies the local server" + ) + + def _validate_retention(self, event: EventBase): """Checks that an event that defines the retention policy for a room respects the format enforced by the spec. Args: - event (FrozenEvent): The event to validate. + event: The event to validate. """ + if not event.is_state(): + raise SynapseError(code=400, msg="must be a state event") + min_lifetime = event.content.get("min_lifetime") max_lifetime = event.content.get("max_lifetime") @@ -113,13 +128,10 @@ class EventValidator: errcode=Codes.BAD_JSON, ) - def validate_builder(self, event): + def validate_builder(self, event: Union[EventBase, EventBuilder]): """Validates that the builder/event has roughly the right format. Only checks values that we expect a proto event to have, rather than all the fields an event would have - - Args: - event (EventBuilder|FrozenEvent) """ strings = ["room_id", "sender", "type"] diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 38aa47963f..383737520a 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -78,6 +78,7 @@ class FederationBase: ctx = current_context() + @defer.inlineCallbacks def callback(_, pdu: EventBase): with PreserveLoggingContext(ctx): if not check_event_content_hash(pdu): @@ -105,7 +106,11 @@ class FederationBase: ) return redacted_event - if self.spam_checker.check_event_for_spam(pdu): + result = yield defer.ensureDeferred( + self.spam_checker.check_event_for_spam(pdu) + ) + + if result: logger.warning( "Event contains spam, redacting %s: %s", pdu.event_id, diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 24329dd0e3..35e345ce70 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
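Among the federation_server.py changes below, on_send_join_request and on_send_leave_request stop taking the room_id from the URL and instead require it in the event JSON, validated with assert_params_in_dict. A minimal sketch of that validation step, with a simplified error type standing in for SynapseError:

def assert_params_in_dict(body, required):
    """Simplified version of the synapse.http.servlet helper of the same name."""
    absent = [key for key in required if key not in body]
    if absent:
        raise ValueError("Missing params: %r" % (absent,))

def extract_room_id(content):
    # Mirrors the new flow in the hunk below: the room_id is carried in the
    # event JSON itself rather than trusted from the request path.
    assert_params_in_dict(content, ["room_id"])
    return content["room_id"]

print(extract_room_id({"room_id": "!abc:example.org", "type": "m.room.member"}))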
@@ -22,7 +22,6 @@ from typing import ( Callable, Dict, List, - Match, Optional, Tuple, Union, @@ -50,6 +49,7 @@ from synapse.federation.federation_base import FederationBase, event_from_pdu_js from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction from synapse.http.endpoint import parse_server_name +from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( make_deferred_yieldable, nested_logging_context, @@ -100,10 +100,15 @@ class FederationServer(FederationBase): super().__init__(hs) self.auth = hs.get_auth() - self.handler = hs.get_handlers().federation_handler + self.handler = hs.get_federation_handler() self.state = hs.get_state_handler() self.device_handler = hs.get_device_handler() + + # Ensure the following handlers are loaded since they register callbacks + # with FederationHandlerRegistry. + hs.get_directory_handler() + self._federation_ratelimiter = hs.get_federation_ratelimiter() self._server_linearizer = Linearizer("fed_server") @@ -112,7 +117,7 @@ class FederationServer(FederationBase): # We cache results for transaction with the same ID self._transaction_resp_cache = ResponseCache( hs, "fed_txn_handler", timeout_ms=30000 - ) + ) # type: ResponseCache[Tuple[str, str]] self.transaction_actions = TransactionActions(self.store) @@ -120,10 +125,12 @@ class FederationServer(FederationBase): # We cache responses to state queries, as they take a while and often # come in waves. - self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000) + self._state_resp_cache = ResponseCache( + hs, "state_resp", timeout_ms=30000 + ) # type: ResponseCache[Tuple[str, str]] self._state_ids_resp_cache = ResponseCache( hs, "state_ids_resp", timeout_ms=30000 - ) + ) # type: ResponseCache[Tuple[str, str]] self._federation_metrics_domains = ( hs.get_config().federation.federation_metrics_domains @@ -385,7 +392,7 @@ class FederationServer(FederationBase): TRANSACTION_CONCURRENCY_LIMIT, ) - async def on_context_state_request( + async def on_room_state_request( self, origin: str, room_id: str, event_id: str ) -> Tuple[int, Dict[str, Any]]: origin_host, _ = parse_server_name(origin) @@ -508,11 +515,12 @@ class FederationServer(FederationBase): return {"event": ret_pdu.get_pdu_json(time_now)} async def on_send_join_request( - self, origin: str, content: JsonDict, room_id: str + self, origin: str, content: JsonDict ) -> Dict[str, Any]: logger.debug("on_send_join_request: content: %s", content) - room_version = await self.store.get_room_version(room_id) + assert_params_in_dict(content, ["room_id"]) + room_version = await self.store.get_room_version(content["room_id"]) pdu = event_from_pdu_json(content, room_version) origin_host, _ = parse_server_name(origin) @@ -541,12 +549,11 @@ class FederationServer(FederationBase): time_now = self._clock.time_msec() return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} - async def on_send_leave_request( - self, origin: str, content: JsonDict, room_id: str - ) -> dict: + async def on_send_leave_request(self, origin: str, content: JsonDict) -> dict: logger.debug("on_send_leave_request: content: %s", content) - room_version = await self.store.get_room_version(room_id) + assert_params_in_dict(content, ["room_id"]) + room_version = await self.store.get_room_version(content["room_id"]) pdu = event_from_pdu_json(content, room_version) origin_host, _ = parse_server_name(origin) @@ -742,12 +749,8 @@ class FederationServer(FederationBase): ) return ret 
- async def on_exchange_third_party_invite_request( - self, room_id: str, event_dict: Dict - ): - ret = await self.handler.on_exchange_third_party_invite_request( - room_id, event_dict - ) + async def on_exchange_third_party_invite_request(self, event_dict: Dict): + ret = await self.handler.on_exchange_third_party_invite_request(event_dict) return ret async def check_server_matches_acl(self, server_name: str, room_id: str): @@ -825,14 +828,14 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool: return False -def _acl_entry_matches(server_name: str, acl_entry: str) -> Match: +def _acl_entry_matches(server_name: str, acl_entry: Any) -> bool: if not isinstance(acl_entry, str): logger.warning( "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry) ) return False regex = glob_to_regex(acl_entry) - return regex.match(server_name) + return bool(regex.match(server_name)) class FederationHandlerRegistry: @@ -842,7 +845,6 @@ class FederationHandlerRegistry: def __init__(self, hs: "HomeServer"): self.config = hs.config - self.http_client = hs.get_simple_http_client() self.clock = hs.get_clock() self._instance_name = hs.get_instance_name() @@ -862,7 +864,7 @@ class FederationHandlerRegistry: self._edu_type_to_instance = {} # type: Dict[str, str] def register_edu_handler( - self, edu_type: str, handler: Callable[[str, dict], Awaitable[None]] + self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]] ): """Sets the handler callable that will be used to handle an incoming federation EDU of the given type. diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 8e46957d15..5f1bf492c1 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -188,7 +188,7 @@ class FederationRemoteSendQueue: for key in keys[:i]: del self.edus[key] - def notify_new_events(self, current_id): + def notify_new_events(self, max_token): """As per FederationSender""" # We don't need to replicate this as it gets sent down a different # stream. diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 8bb17b3a05..604cfd1935 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
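In the sender changes below, notify_new_events now takes a RoomStreamToken instead of a bare stream id and deliberately uses only its integer stream component, ignoring the vector-clock part. A toy illustration of that, with a simplified token type standing in for synapse.types.RoomStreamToken:

import attr

@attr.s(frozen=True, slots=True)
class RoomStreamToken:
    """Simplified stand-in: a stream ordering plus a per-writer vector clock."""
    stream = attr.ib(type=int)
    instance_map = attr.ib(type=dict, factory=dict)

class FakeSender:
    def __init__(self):
        self._last_poked_id = 0

    def notify_new_events(self, max_token: RoomStreamToken) -> None:
        # Use only the minimum stream ordering; the vector-clock component is
        # ignored consistently, so comparisons remain well defined.
        current_id = max_token.stream
        self._last_poked_id = max(current_id, self._last_poked_id)

sender = FakeSender()
sender.notify_new_events(RoomStreamToken(stream=42, instance_map={"worker1": 45}))
print(sender._last_poked_id)  # 42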
@@ -40,7 +40,7 @@ from synapse.metrics import ( events_processed_counter, ) from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import ReadReceipt +from synapse.types import ReadReceipt, RoomStreamToken from synapse.util.metrics import Measure, measure_func logger = logging.getLogger(__name__) @@ -154,10 +154,15 @@ class FederationSender: self._per_destination_queues[destination] = queue return queue - def notify_new_events(self, current_id: int) -> None: + def notify_new_events(self, max_token: RoomStreamToken) -> None: """This gets called when we have some new events we might want to send out to other servers. """ + # We just use the minimum stream ordering and ignore the vector clock + # component. This is safe to do as long as we *always* ignore the vector + # clock components. + current_id = max_token.stream + self._last_poked_id = max(current_id, self._last_poked_id) if self._is_processing: @@ -297,6 +302,8 @@ class FederationSender: sent_pdus_destination_dist_total.inc(len(destinations)) sent_pdus_destination_dist_count.inc() + assert pdu.internal_metadata.stream_ordering + # track the fact that we have a PDU for these destinations, # to allow us to perform catch-up later on if the remote is unreachable # for a while. diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index bc99af3fdd..db8e456fe8 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -158,6 +158,7 @@ class PerDestinationQueue: # yet know if we have anything to catch up (None) self._pending_pdus.append(pdu) else: + assert pdu.internal_metadata.stream_ordering self._catchup_last_skipped = pdu.internal_metadata.stream_ordering self.attempt_new_transaction() @@ -361,6 +362,7 @@ class PerDestinationQueue: last_successful_stream_ordering = ( final_pdu.internal_metadata.stream_ordering ) + assert last_successful_stream_ordering await self._store.set_destination_last_successful_stream_ordering( self._destination, last_successful_stream_ordering ) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 17a10f622e..abe9168c78 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -35,7 +35,7 @@ class TransportLayerClient: def __init__(self, hs): self.server_name = hs.hostname - self.client = hs.get_http_client() + self.client = hs.get_federation_http_client() @log_function def get_room_state_ids(self, destination, room_id, event_id): diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 3a6b95631e..cfd094e58f 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
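The transport servlet changes below rename the context path parameter to room_id; each servlet's PATH is a regex whose named groups are passed to on_GET/on_PUT as keyword arguments. A small demonstration of the capture itself, using plain re outside Synapse's servlet machinery:

import re

# Same style of pattern as FederationStateV1Servlet.PATH in the hunk below.
PATH = re.compile("^/state/(?P<room_id>[^/]*)/?$")

match = PATH.match("/state/!abc:example.org/")
print(match.groupdict())  # {'room_id': '!abc:example.org'}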
@@ -144,7 +144,7 @@ class Authenticator: ): raise FederationDeniedError(origin) - if not json_request["signatures"]: + if origin is None or not json_request["signatures"]: raise NoAuthenticationError( 401, "Missing Authorization headers", Codes.UNAUTHORIZED ) @@ -154,7 +154,7 @@ class Authenticator: ) logger.debug("Request from %s", origin) - request.authenticated_entity = origin + request.requester = origin # If we get a valid signed request from the other side, its probably # alive @@ -440,13 +440,13 @@ class FederationEventServlet(BaseFederationServlet): class FederationStateV1Servlet(BaseFederationServlet): - PATH = "/state/(?P<context>[^/]*)/?" + PATH = "/state/(?P<room_id>[^/]*)/?" - # This is when someone asks for all data for a given context. - async def on_GET(self, origin, content, query, context): - return await self.handler.on_context_state_request( + # This is when someone asks for all data for a given room. + async def on_GET(self, origin, content, query, room_id): + return await self.handler.on_room_state_request( origin, - context, + room_id, parse_string_from_args(query, "event_id", None, required=False), ) @@ -463,16 +463,16 @@ class FederationStateIdsServlet(BaseFederationServlet): class FederationBackfillServlet(BaseFederationServlet): - PATH = "/backfill/(?P<context>[^/]*)/?" + PATH = "/backfill/(?P<room_id>[^/]*)/?" - async def on_GET(self, origin, content, query, context): + async def on_GET(self, origin, content, query, room_id): versions = [x.decode("ascii") for x in query[b"v"]] limit = parse_integer_from_args(query, "limit", None) if not limit: return 400, {"error": "Did not include limit param"} - return await self.handler.on_backfill_request(origin, context, versions, limit) + return await self.handler.on_backfill_request(origin, room_id, versions, limit) class FederationQueryServlet(BaseFederationServlet): @@ -487,9 +487,9 @@ class FederationQueryServlet(BaseFederationServlet): class FederationMakeJoinServlet(BaseFederationServlet): - PATH = "/make_join/(?P<context>[^/]*)/(?P<user_id>[^/]*)" + PATH = "/make_join/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)" - async def on_GET(self, origin, _content, query, context, user_id): + async def on_GET(self, origin, _content, query, room_id, user_id): """ Args: origin (unicode): The authenticated server_name of the calling server @@ -511,16 +511,16 @@ class FederationMakeJoinServlet(BaseFederationServlet): supported_versions = ["1"] content = await self.handler.on_make_join_request( - origin, context, user_id, supported_versions=supported_versions + origin, room_id, user_id, supported_versions=supported_versions ) return 200, content class FederationMakeLeaveServlet(BaseFederationServlet): - PATH = "/make_leave/(?P<context>[^/]*)/(?P<user_id>[^/]*)" + PATH = "/make_leave/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)" - async def on_GET(self, origin, content, query, context, user_id): - content = await self.handler.on_make_leave_request(origin, context, user_id) + async def on_GET(self, origin, content, query, room_id, user_id): + content = await self.handler.on_make_leave_request(origin, room_id, user_id) return 200, content @@ -528,7 +528,7 @@ class FederationV1SendLeaveServlet(BaseFederationServlet): PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" async def on_PUT(self, origin, content, query, room_id, event_id): - content = await self.handler.on_send_leave_request(origin, content, room_id) + content = await self.handler.on_send_leave_request(origin, content) return 200, (200, content) @@ -538,43 +538,43 @@ class 
FederationV2SendLeaveServlet(BaseFederationServlet): PREFIX = FEDERATION_V2_PREFIX async def on_PUT(self, origin, content, query, room_id, event_id): - content = await self.handler.on_send_leave_request(origin, content, room_id) + content = await self.handler.on_send_leave_request(origin, content) return 200, content class FederationEventAuthServlet(BaseFederationServlet): - PATH = "/event_auth/(?P<context>[^/]*)/(?P<event_id>[^/]*)" + PATH = "/event_auth/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" - async def on_GET(self, origin, content, query, context, event_id): - return await self.handler.on_event_auth(origin, context, event_id) + async def on_GET(self, origin, content, query, room_id, event_id): + return await self.handler.on_event_auth(origin, room_id, event_id) class FederationV1SendJoinServlet(BaseFederationServlet): - PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)" + PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" - async def on_PUT(self, origin, content, query, context, event_id): - # TODO(paul): assert that context/event_id parsed from path actually + async def on_PUT(self, origin, content, query, room_id, event_id): + # TODO(paul): assert that room_id/event_id parsed from path actually # match those given in content - content = await self.handler.on_send_join_request(origin, content, context) + content = await self.handler.on_send_join_request(origin, content) return 200, (200, content) class FederationV2SendJoinServlet(BaseFederationServlet): - PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)" + PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" PREFIX = FEDERATION_V2_PREFIX - async def on_PUT(self, origin, content, query, context, event_id): - # TODO(paul): assert that context/event_id parsed from path actually + async def on_PUT(self, origin, content, query, room_id, event_id): + # TODO(paul): assert that room_id/event_id parsed from path actually # match those given in content - content = await self.handler.on_send_join_request(origin, content, context) + content = await self.handler.on_send_join_request(origin, content) return 200, content class FederationV1InviteServlet(BaseFederationServlet): - PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)" + PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" - async def on_PUT(self, origin, content, query, context, event_id): + async def on_PUT(self, origin, content, query, room_id, event_id): # We don't get a room version, so we have to assume its EITHER v1 or # v2. 
This is "fine" as the only difference between V1 and V2 is the # state resolution algorithm, and we don't use that for processing @@ -589,12 +589,12 @@ class FederationV1InviteServlet(BaseFederationServlet): class FederationV2InviteServlet(BaseFederationServlet): - PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)" + PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" PREFIX = FEDERATION_V2_PREFIX - async def on_PUT(self, origin, content, query, context, event_id): - # TODO(paul): assert that context/event_id parsed from path actually + async def on_PUT(self, origin, content, query, room_id, event_id): + # TODO(paul): assert that room_id/event_id parsed from path actually # match those given in content room_version = content["room_version"] @@ -616,9 +616,7 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet): PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)" async def on_PUT(self, origin, content, query, room_id): - content = await self.handler.on_exchange_third_party_invite_request( - room_id, content - ) + content = await self.handler.on_exchange_third_party_invite_request(content) return 200, content @@ -1464,7 +1462,7 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N Args: hs (synapse.server.HomeServer): homeserver - resource (TransportLayerServer): resource class to register to + resource (JsonResource): resource class to register to authenticator (Authenticator): authenticator to use ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use servlet_groups (list[str], optional): List of servlet groups to register. diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index a86b3debc5..41cf07cc88 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -22,7 +22,7 @@ attestations have a validity period so need to be periodically renewed. If a user leaves (or gets kicked out of) a group, either side can still use their attestation to "prove" their membership, until the attestation expires. Therefore attestations shouldn't be relied on to prove membership in important -cases, but can for less important situtations, e.g. showing a users membership +cases, but can for less important situations, e.g. showing a users membership of groups on their profile, showing flairs, etc. An attestation is a signed blob of json that looks like: diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index e5f85b472d..0d042cbfac 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
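The first hunk below only rewords the comment on `entry = dict(entry)`, but that comment is worth a moment: the room entry comes from a cache, so it is copied before `room_id` is popped off it. The hazard being avoided, reduced to plain Python (not Synapse code):

# Why a cached dict is copied before mutation.
_cache = {"!abc:example.org": {"room_id": "!abc:example.org", "name": "Lounge"}}

def room_profile_bad(room_id: str) -> dict:
    entry = _cache[room_id]        # shared reference into the cache
    entry.pop("room_id", None)     # mutates the cached value for every later caller
    return entry

def room_profile_good(room_id: str) -> dict:
    entry = dict(_cache[room_id])  # shallow copy, so we don't change what's cached
    entry.pop("room_id", None)
    return entry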
@@ -113,7 +113,7 @@ class GroupsServerWorkerHandler: entry = await self.room_list_handler.generate_room_entry( room_id, len(joined_users), with_alias=False, allow_private=True ) - entry = dict(entry) # so we don't change whats cached + entry = dict(entry) # so we don't change what's cached entry.pop("room_id", None) room_entry["profile"] = entry @@ -550,7 +550,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): group_id, room_id, is_public=is_public ) else: - raise SynapseError(400, "Uknown config option") + raise SynapseError(400, "Unknown config option") return {} diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py
index 286f0054be..bfebb0f644 100644
--- a/synapse/handlers/__init__.py
+++ b/synapse/handlers/__init__.py
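The hunk below deletes the deprecated `Handlers` aggregate; as its own docstring explains, handlers are now reached through per-handler getters on the homeserver (elsewhere in this diff, `hs.get_identity_handler()` replaces `hs.get_handlers().identity_handler`). A toy version of that pattern, with hypothetical names and a simple cache in place of Synapse's own caching decorator:

class DirectoryHandler:
    def __init__(self, hs: "HomeServer"):
        self.hs = hs

class HomeServer:
    # One lazily-built, cached getter per handler, so tests and cut-down
    # deployments only construct the handlers they actually use.
    _directory_handler = None

    def get_directory_handler(self) -> DirectoryHandler:
        if self._directory_handler is None:
            self._directory_handler = DirectoryHandler(self)
        return self._directory_handler

hs = HomeServer()
assert hs.get_directory_handler() is hs.get_directory_handler()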
@@ -12,36 +12,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from .admin import AdminHandler -from .directory import DirectoryHandler -from .federation import FederationHandler -from .identity import IdentityHandler -from .search import SearchHandler - - -class Handlers: - - """ Deprecated. A collection of handlers. - - At some point most of the classes whose name ended "Handler" were - accessed through this class. - - However this makes it painful to unit test the handlers and to run cut - down versions of synapse that only use specific handlers because using a - single handler required creating all of the handlers. So some of the - handlers have been lifted out of the Handlers object and are now accessed - directly through the homeserver object itself. - - Any new handlers should follow the new pattern of being accessed through - the homeserver object and should not be added to the Handlers object. - - The remaining handlers should be moved out of the handlers object. - """ - - def __init__(self, hs): - self.federation_handler = FederationHandler(hs) - self.directory_handler = DirectoryHandler(hs) - self.admin_handler = AdminHandler(hs) - self.identity_handler = IdentityHandler(hs) - self.search_handler = SearchHandler(hs) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 0206320e96..d29b066a56 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
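Most handler files in this diff gain the same `if TYPE_CHECKING:` block so that `hs` can be annotated as `"HomeServer"` without importing the server module at runtime, which would create an import cycle. The pattern in isolation, with a hypothetical module path:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers such as mypy, never executed at runtime,
    # so there is no circular import between handler and server modules.
    from myapp.server import HomeServer  # hypothetical module path

class SomeHandler:
    def __init__(self, hs: "HomeServer"):  # string annotation, resolved lazily
        self.store = hs.get_datastore()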
@@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Optional import synapse.state import synapse.storage @@ -22,19 +23,22 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.ratelimiting import Ratelimiter from synapse.types import UserID +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class BaseHandler: """ Common base class for the event handlers. + + Deprecated: new code should not use this. Instead, Handler classes should define the + fields they actually need. The utility methods should either be factored out to + standalone helper functions, or to different Handler classes. """ - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): - """ + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() # type: synapse.storage.DataStore self.auth = hs.get_auth() self.notifier = hs.get_notifier() @@ -56,7 +60,7 @@ class BaseHandler: clock=self.clock, rate_hz=self.hs.config.rc_admin_redaction.per_second, burst_count=self.hs.config.rc_admin_redaction.burst_count, - ) + ) # type: Optional[Ratelimiter] else: self.admin_redaction_ratelimiter = None @@ -127,15 +131,15 @@ class BaseHandler: if guest_access != "can_join": if context: current_state_ids = await context.get_current_state_ids() - current_state = await self.store.get_events( + current_state_dict = await self.store.get_events( list(current_state_ids.values()) ) + current_state = list(current_state_dict.values()) else: - current_state = await self.state_handler.get_current_state( + current_state_map = await self.state_handler.get_current_state( event.room_id ) - - current_state = list(current_state.values()) + current_state = list(current_state_map.values()) logger.info("maybe_kick_guest_users %r", current_state) await self.kick_guest_users(current_state) @@ -169,7 +173,9 @@ class BaseHandler: # and having homeservers have their own users leave keeps more # of that decision-making and control local to the guest-having # homeserver. - requester = synapse.types.create_requester(target_user, is_guest=True) + requester = synapse.types.create_requester( + target_user, is_guest=True, authenticated_entity=self.server_name + ) handler = self.hs.get_room_member_handler() await handler.update_membership( requester, diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 9112a0ab86..341135822e 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -12,16 +12,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING, List, Tuple + +from synapse.types import JsonDict, UserID + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer class AccountDataEventSource: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() - def get_current_key(self, direction="f"): + def get_current_key(self, direction: str = "f") -> int: return self.store.get_max_account_data_stream_id() - async def get_new_events(self, user, from_key, **kwargs): + async def get_new_events( + self, user: UserID, from_key: int, **kwargs + ) -> Tuple[List[JsonDict], int]: user_id = user.to_string() last_stream_id = from_key diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 4caf6d591a..664d09da1c 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
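Below, the hand-rolled `send_emails` closure wrapped in `run_as_background_process` is replaced by the `@wrap_as_background_process("send_renewals")` decorator, and the 30-minute loop is only scheduled when this instance is configured to run background tasks. A toy version of what such a decorator does; the real one also manages logging contexts and metrics:

import asyncio
import functools

def wrap_as_background_process(desc: str):
    # Toy stand-in: run the wrapped coroutine as a fire-and-forget task so a
    # synchronous caller (such as a looping timer) can invoke it without awaiting.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return asyncio.create_task(func(*args, **kwargs))
        return wrapper
    return decorator

@wrap_as_background_process("send_renewals")
async def send_renewal_emails() -> None:
    print("checking which accounts need a renewal email")

async def main() -> None:
    send_renewal_emails()   # returns immediately; the work happens in the background
    await asyncio.sleep(0)  # let the scheduled task run before we exit

asyncio.run(main())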
@@ -18,19 +18,22 @@ import email.utils import logging from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText -from typing import List +from typing import TYPE_CHECKING, List -from synapse.api.errors import StoreError +from synapse.api.errors import StoreError, SynapseError from synapse.logging.context import make_deferred_yieldable -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.types import UserID from synapse.util import stringutils +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class AccountValidityHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.config = hs.config self.store = self.hs.get_datastore() @@ -63,16 +66,11 @@ class AccountValidityHandler: self._raw_from = email.utils.parseaddr(self._from_string)[1] # Check the renewal emails to send and send them every 30min. - def send_emails(): - # run as a background process to make sure that the database transactions - # have a logcontext to report to - return run_as_background_process( - "send_renewals", self._send_renewal_emails - ) - - self.clock.looping_call(send_emails, 30 * 60 * 1000) + if hs.config.run_background_tasks: + self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000) - async def _send_renewal_emails(self): + @wrap_as_background_process("send_renewals") + async def _send_renewal_emails(self) -> None: """Gets the list of users whose account is expiring in the amount of time configured in the ``renew_at`` parameter from the ``account_validity`` configuration, and sends renewal emails to all of these users as long as they @@ -86,11 +84,25 @@ class AccountValidityHandler: user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"] ) - async def send_renewal_email_to_user(self, user_id: str): + async def send_renewal_email_to_user(self, user_id: str) -> None: + """ + Send a renewal email for a specific user. + + Args: + user_id: The user ID to send a renewal email for. + + Raises: + SynapseError if the user is not set to renew. + """ expiration_ts = await self.store.get_expiration_ts_for_user(user_id) + + # If this user isn't set to be expired, raise an error. + if expiration_ts is None: + raise SynapseError(400, "User has no expiration time: %s" % (user_id,)) + await self._send_renewal_email(user_id, expiration_ts) - async def _send_renewal_email(self, user_id: str, expiration_ts: int): + async def _send_renewal_email(self, user_id: str, expiration_ts: int) -> None: """Sends out a renewal email to every email address attached to the given user with a unique link allowing them to renew their account. diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 1ce2091b46..37e63da9b1 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
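In the hunks below, `ExfiltrationWriter` becomes an abstract base class, so callers of `export_user_data` must supply a concrete writer implementing all four methods. A minimal in-memory writer against that interface might look like this (the method names come from the class below; everything else is illustrative):

import abc
from typing import Any, Dict, List, Tuple

class ExfiltrationWriter(metaclass=abc.ABCMeta):
    # Abridged copy of the interface defined in the hunk below.
    @abc.abstractmethod
    def write_events(self, room_id: str, events: List[Any]) -> None:
        ...

    @abc.abstractmethod
    def write_state(self, room_id: str, event_id: str, state: Any) -> None:
        ...

    @abc.abstractmethod
    def write_invite(self, room_id: str, event: Any, state: Any) -> None:
        ...

    @abc.abstractmethod
    def finished(self) -> Any:
        ...

class InMemoryExfiltrationWriter(ExfiltrationWriter):
    # Hypothetical writer that simply accumulates everything in dictionaries.
    def __init__(self) -> None:
        self.events = {}   # type: Dict[str, List[Any]]
        self.state = {}    # type: Dict[Tuple[str, str], Any]
        self.invites = {}  # type: Dict[str, Tuple[Any, Any]]

    def write_events(self, room_id: str, events: List[Any]) -> None:
        self.events.setdefault(room_id, []).extend(events)

    def write_state(self, room_id: str, event_id: str, state: Any) -> None:
        self.state[(room_id, event_id)] = state

    def write_invite(self, room_id: str, event: Any, state: Any) -> None:
        self.invites[room_id] = (event, state)

    def finished(self) -> Any:
        # The return value is handed back to the caller of export_user_data().
        return {"events": self.events, "state": self.state, "invites": self.invites}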
@@ -13,27 +13,31 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc import logging -from typing import List +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set from synapse.api.constants import Membership -from synapse.events import FrozenEvent -from synapse.types import RoomStreamToken, StateMap +from synapse.events import EventBase +from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID from synapse.visibility import filter_events_for_client from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class AdminHandler(BaseHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.storage = hs.get_storage() self.state_store = self.storage.state - async def get_whois(self, user): + async def get_whois(self, user: UserID) -> JsonDict: connections = [] sessions = await self.store.get_user_ip_and_agents(user) @@ -53,7 +57,7 @@ class AdminHandler(BaseHandler): return ret - async def get_user(self, user): + async def get_user(self, user: UserID) -> Optional[JsonDict]: """Function to get user details""" ret = await self.store.get_user_by_id(user.to_string()) if ret: @@ -64,12 +68,12 @@ class AdminHandler(BaseHandler): ret["threepids"] = threepids return ret - async def export_user_data(self, user_id, writer): + async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any: """Write all data we have on the user to the given writer. Args: - user_id (str) - writer (ExfiltrationWriter) + user_id: The user ID to fetch data of. + writer: The writer to write to. Returns: Resolves when all data for a user has been written. @@ -88,7 +92,7 @@ class AdminHandler(BaseHandler): # We only try and fetch events for rooms the user has been in. If # they've been e.g. invited to a room without joining then we handle - # those seperately. + # those separately. rooms_user_has_been_in = await self.store.get_rooms_user_has_been_in(user_id) for index, room in enumerate(rooms): @@ -128,7 +132,8 @@ class AdminHandler(BaseHandler): from_key = RoomStreamToken(0, 0) to_key = RoomStreamToken(None, stream_ordering) - written_events = set() # Events that we've processed in this room + # Events that we've processed in this room + written_events = set() # type: Set[str] # We need to track gaps in the events stream so that we can then # write out the state at those events. We do this by keeping track @@ -140,8 +145,8 @@ class AdminHandler(BaseHandler): # The reverse mapping to above, i.e. map from unseen event to events # that have the unseen event in their prev_events, i.e. the unseen - # events "children". dict[str, set[str]] - unseen_to_child_events = {} + # events "children". + unseen_to_child_events = {} # type: Dict[str, Set[str]] # We fetch events in the room the user could see by fetching *all* # events that we have and then filtering, this isn't the most @@ -197,38 +202,46 @@ class AdminHandler(BaseHandler): return writer.finished() -class ExfiltrationWriter: +class ExfiltrationWriter(metaclass=abc.ABCMeta): """Interface used to specify how to write exported data. """ - def write_events(self, room_id: str, events: List[FrozenEvent]): + @abc.abstractmethod + def write_events(self, room_id: str, events: List[EventBase]) -> None: """Write a batch of events for a room. 
""" - pass + raise NotImplementedError() - def write_state(self, room_id: str, event_id: str, state: StateMap[FrozenEvent]): + @abc.abstractmethod + def write_state( + self, room_id: str, event_id: str, state: StateMap[EventBase] + ) -> None: """Write the state at the given event in the room. This only gets called for backward extremities rather than for each event. """ - pass + raise NotImplementedError() - def write_invite(self, room_id: str, event: FrozenEvent, state: StateMap[dict]): + @abc.abstractmethod + def write_invite( + self, room_id: str, event: EventBase, state: StateMap[dict] + ) -> None: """Write an invite for the room, with associated invite state. Args: - room_id - event - state: A subset of the state at the - invite, with a subset of the event keys (type, state_key - content and sender) + room_id: The room ID the invite is for. + event: The invite event. + state: A subset of the state at the invite, with a subset of the + event keys (type, state_key content and sender). """ + raise NotImplementedError() - def finished(self): - """Called when all data has succesfully been exported and written. + @abc.abstractmethod + def finished(self) -> Any: + """Called when all data has successfully been exported and written. This functions return value is passed to the caller of `export_user_data`. """ - pass + raise NotImplementedError() diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 9d4e87dad6..5c6458eb52 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
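The new ephemeral-event push below stores a per-appservice stream position for read receipts and presence (via `get_type_stream_id_for_appservice` and `set_type_stream_id_for_appservice`), while typing is recomputed from the latest token rather than persisted, as the comments note. Stripped of Synapse specifics, the resume-from-a-stored-token pattern is roughly this (all names hypothetical):

from typing import Dict, List, Tuple

# (service_id, stream_name) -> last stream position this service was sent
_positions = {}  # type: Dict[Tuple[str, str], int]

def get_stream_position(service_id: str, stream: str) -> int:
    return _positions.get((service_id, stream), 0)

def set_stream_position(service_id: str, stream: str, token: int) -> None:
    _positions[(service_id, stream)] = token

def new_receipts_for(service_id: str, receipts: List[Tuple[int, dict]]) -> List[dict]:
    # Each receipt is (stream_token, event); only hand over what this service
    # has not seen yet, then advance its stored position.
    from_token = get_stream_position(service_id, "read_receipt")
    fresh = [event for token, event in receipts if token > from_token]
    if receipts:
        set_stream_position(service_id, "read_receipt", max(t for t, _ in receipts))
    return fresh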
@@ -12,8 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging +from typing import TYPE_CHECKING, Dict, List, Optional, Union from prometheus_client import Counter @@ -21,21 +21,32 @@ from twisted.internet import defer import synapse from synapse.api.constants import EventTypes +from synapse.appservice import ApplicationService +from synapse.events import EventBase +from synapse.handlers.presence import format_user_presence_state from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics import ( event_processing_loop_counter, event_processing_loop_room_count, ) -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) +from synapse.storage.databases.main.directory import RoomAliasMapping +from synapse.types import Collection, JsonDict, RoomAlias, RoomStreamToken, UserID from synapse.util.metrics import Measure +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "") class ApplicationServicesHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.is_mine_id = hs.is_mine_id self.appservice_api = hs.get_application_service_api() @@ -43,19 +54,22 @@ class ApplicationServicesHandler: self.started_scheduler = False self.clock = hs.get_clock() self.notify_appservices = hs.config.notify_appservices + self.event_sources = hs.get_event_sources() self.current_max = 0 self.is_processing = False - async def notify_interested_services(self, current_id): + def notify_interested_services(self, max_token: RoomStreamToken): """Notifies (pushes) all application services interested in this event. Pushing is done asynchronously, so this method won't block for any prolonged length of time. - - Args: - current_id(int): The current maximum ID. """ + # We just use the minimum stream ordering and ignore the vector clock + # component. This is safe to do as long as we *always* ignore the vector + # clock components. + current_id = max_token.stream + services = self.store.get_app_services() if not services or not self.notify_appservices: return @@ -64,6 +78,12 @@ class ApplicationServicesHandler: if self.is_processing: return + # We only start a new background process if necessary rather than + # optimistically (to cut down on overhead). 
+ self._notify_interested_services(max_token) + + @wrap_as_background_process("notify_interested_services") + async def _notify_interested_services(self, max_token: RoomStreamToken): with Measure(self.clock, "notify_interested_services"): self.is_processing = True try: @@ -79,7 +99,7 @@ class ApplicationServicesHandler: if not events: break - events_by_room = {} + events_by_room = {} # type: Dict[str, List[EventBase]] for event in events: events_by_room.setdefault(event.room_id, []).append(event) @@ -158,11 +178,139 @@ class ApplicationServicesHandler: finally: self.is_processing = False - async def query_user_exists(self, user_id): + def notify_interested_services_ephemeral( + self, + stream_key: str, + new_token: Optional[int], + users: Collection[Union[str, UserID]] = [], + ): + """This is called by the notifier in the background + when a ephemeral event handled by the homeserver. + + This will determine which appservices + are interested in the event, and submit them. + + Events will only be pushed to appservices + that have opted into ephemeral events + + Args: + stream_key: The stream the event came from. + new_token: The latest stream token + users: The user(s) involved with the event. + """ + if not self.notify_appservices: + return + + if stream_key not in ("typing_key", "receipt_key", "presence_key"): + return + + services = [ + service + for service in self.store.get_app_services() + if service.supports_ephemeral + ] + if not services: + return + + # We only start a new background process if necessary rather than + # optimistically (to cut down on overhead). + self._notify_interested_services_ephemeral( + services, stream_key, new_token, users + ) + + @wrap_as_background_process("notify_interested_services_ephemeral") + async def _notify_interested_services_ephemeral( + self, + services: List[ApplicationService], + stream_key: str, + new_token: Optional[int], + users: Collection[Union[str, UserID]], + ): + logger.debug("Checking interested services for %s" % (stream_key)) + with Measure(self.clock, "notify_interested_services_ephemeral"): + for service in services: + # Only handle typing if we have the latest token + if stream_key == "typing_key" and new_token is not None: + events = await self._handle_typing(service, new_token) + if events: + self.scheduler.submit_ephemeral_events_for_as(service, events) + # We don't persist the token for typing_key for performance reasons + elif stream_key == "receipt_key": + events = await self._handle_receipts(service) + if events: + self.scheduler.submit_ephemeral_events_for_as(service, events) + await self.store.set_type_stream_id_for_appservice( + service, "read_receipt", new_token + ) + elif stream_key == "presence_key": + events = await self._handle_presence(service, users) + if events: + self.scheduler.submit_ephemeral_events_for_as(service, events) + await self.store.set_type_stream_id_for_appservice( + service, "presence", new_token + ) + + async def _handle_typing( + self, service: ApplicationService, new_token: int + ) -> List[JsonDict]: + typing_source = self.event_sources.sources["typing"] + # Get the typing events from just before current + typing, _ = await typing_source.get_new_events_as( + service=service, + # For performance reasons, we don't persist the previous + # token in the DB and instead fetch the latest typing information + # for appservices. 
+ from_key=new_token - 1, + ) + return typing + + async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]: + from_key = await self.store.get_type_stream_id_for_appservice( + service, "read_receipt" + ) + receipts_source = self.event_sources.sources["receipt"] + receipts, _ = await receipts_source.get_new_events_as( + service=service, from_key=from_key + ) + return receipts + + async def _handle_presence( + self, service: ApplicationService, users: Collection[Union[str, UserID]] + ) -> List[JsonDict]: + events = [] # type: List[JsonDict] + presence_source = self.event_sources.sources["presence"] + from_key = await self.store.get_type_stream_id_for_appservice( + service, "presence" + ) + for user in users: + if isinstance(user, str): + user = UserID.from_string(user) + + interested = await service.is_interested_in_presence(user, self.store) + if not interested: + continue + presence_events, _ = await presence_source.get_new_events( + user=user, service=service, from_key=from_key, + ) + time_now = self.clock.time_msec() + events.extend( + { + "type": "m.presence", + "sender": event.user_id, + "content": format_user_presence_state( + event, time_now, include_user_id=False + ), + } + for event in presence_events + ) + + return events + + async def query_user_exists(self, user_id: str) -> bool: """Check if any application service knows this user_id exists. Args: - user_id(str): The user to query if they exist on any AS. + user_id: The user to query if they exist on any AS. Returns: True if this user exists on at least one application service. """ @@ -173,11 +321,13 @@ class ApplicationServicesHandler: return True return False - async def query_room_alias_exists(self, room_alias): + async def query_room_alias_exists( + self, room_alias: RoomAlias + ) -> Optional[RoomAliasMapping]: """Check if an application service knows this room alias exists. Args: - room_alias(RoomAlias): The room alias to query. + room_alias: The room alias to query. Returns: namedtuple: with keys "room_id" and "servers" or None if no association can be found. @@ -193,10 +343,13 @@ class ApplicationServicesHandler: ) if is_known_alias: # the alias exists now so don't query more ASes. 
- result = await self.store.get_association_from_room_alias(room_alias) - return result + return await self.store.get_association_from_room_alias(room_alias) - async def query_3pe(self, kind, protocol, fields): + return None + + async def query_3pe( + self, kind: str, protocol: str, fields: Dict[bytes, List[bytes]] + ) -> List[JsonDict]: services = self._get_services_for_3pn(protocol) results = await make_deferred_yieldable( @@ -218,9 +371,11 @@ class ApplicationServicesHandler: return ret - async def get_3pe_protocols(self, only_protocol=None): + async def get_3pe_protocols( + self, only_protocol: Optional[str] = None + ) -> Dict[str, JsonDict]: services = self.store.get_app_services() - protocols = {} + protocols = {} # type: Dict[str, List[JsonDict]] # Collect up all the individual protocol responses out of the ASes for s in services: @@ -236,7 +391,7 @@ class ApplicationServicesHandler: if info is not None: protocols[p].append(info) - def _merge_instances(infos): + def _merge_instances(infos: List[JsonDict]) -> JsonDict: if not infos: return {} @@ -251,19 +406,17 @@ class ApplicationServicesHandler: return combined - for p in protocols.keys(): - protocols[p] = _merge_instances(protocols[p]) - - return protocols + return {p: _merge_instances(protocols[p]) for p in protocols.keys()} - async def _get_services_for_event(self, event): + async def _get_services_for_event( + self, event: EventBase + ) -> List[ApplicationService]: """Retrieve a list of application services interested in this event. Args: - event(Event): The event to check. Can be None if alias_list is not. + event: The event to check. Can be None if alias_list is not. Returns: - list<ApplicationService>: A list of services interested in this - event based on the service regex. + A list of services interested in this event based on the service regex. """ services = self.store.get_app_services() @@ -277,17 +430,15 @@ class ApplicationServicesHandler: return interested_list - def _get_services_for_user(self, user_id): + def _get_services_for_user(self, user_id: str) -> List[ApplicationService]: services = self.store.get_app_services() - interested_list = [s for s in services if (s.is_interested_in_user(user_id))] - return interested_list + return [s for s in services if (s.is_interested_in_user(user_id))] - def _get_services_for_3pn(self, protocol): + def _get_services_for_3pn(self, protocol: str) -> List[ApplicationService]: services = self.store.get_app_services() - interested_list = [s for s in services if s.is_interested_in_protocol(protocol)] - return interested_list + return [s for s in services if s.is_interested_in_protocol(protocol)] - async def _is_unknown_user(self, user_id): + async def _is_unknown_user(self, user_id: str) -> bool: if not self.is_mine_id(user_id): # we don't know if they are unknown or not since it isn't one of our # users. We can't poke ASes. @@ -302,9 +453,8 @@ class ApplicationServicesHandler: service_list = [s for s in services if s.sender == user_id] return len(service_list) == 0 - async def _check_user_exists(self, user_id): + async def _check_user_exists(self, user_id: str) -> bool: unknown_user = await self._is_unknown_user(user_id) if unknown_user: - exists = await self.query_user_exists(user_id) - return exists + return await self.query_user_exists(user_id) return True diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 00eae92052..f4434673dc 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
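Among the changes below, the new `PasswordProvider` wrapper "grandfathers in" modules that only define `check_password`, advertising `m.login.password` on their behalf. A minimal legacy-style provider that such a wrapper could load might look like this; the `parse_config`/`__init__`/`check_password` shape follows the documented password-provider interface, while the shared-password logic is purely illustrative:

# my_password_provider.py -- sketch of a minimal legacy-style password provider.

class _Config:
    pass

class DummyPasswordProvider:
    def __init__(self, config, account_handler):
        # account_handler is the ModuleApi instance handed in by the wrapper.
        self._account_handler = account_handler
        self._shared_password = config.shared_password

    @staticmethod
    def parse_config(config_dict: dict) -> _Config:
        cfg = _Config()
        cfg.shared_password = config_dict.get("shared_password", "hunter2")
        return cfg

    async def check_password(self, user_id: str, password: str) -> bool:
        # Only check_password() is defined here, so the PasswordProvider wrapper
        # adds m.login.password to this module's supported login types.
        return password == self._shared_password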
@@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014 - 2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd +# Copyright 2019 - 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,17 +14,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect import logging import time import unicodedata import urllib.parse -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Tuple, + Union, +) import attr -import bcrypt # type: ignore[import] +import bcrypt import pymacaroons +from twisted.web.http import Request + from synapse.api.constants import LoginType from synapse.api.errors import ( AuthError, @@ -44,11 +58,15 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import ModuleApi from synapse.types import JsonDict, Requester, UserID from synapse.util import stringutils as stringutils +from synapse.util.async_helpers import maybe_awaitable from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.threepids import canonicalise_email from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -149,11 +167,7 @@ class SsoLoginExtraAttributes: class AuthHandler(BaseHandler): SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): - """ + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker] @@ -164,42 +178,45 @@ class AuthHandler(BaseHandler): self.bcrypt_rounds = hs.config.bcrypt_rounds + # we can't use hs.get_module_api() here, because to do so will create an + # import loop. + # + # TODO: refactor this class to separate the lower-level stuff that + # ModuleApi can use from the higher-level stuff that uses ModuleApi, as + # better way to break the loop account_handler = ModuleApi(hs, self) + self.password_providers = [ - module(config=config, account_handler=account_handler) + PasswordProvider.load(module, config, account_handler) for module, config in hs.config.password_providers ] - logger.info("Extra password_providers: %r", self.password_providers) + logger.info("Extra password_providers: %s", self.password_providers) self.hs = hs # FIXME better possibility to access registrationHandler later? self.macaroon_gen = hs.get_macaroon_generator() self._password_enabled = hs.config.password_enabled - self._sso_enabled = ( - hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled - ) + self._password_localdb_enabled = hs.config.password_localdb_enabled + + # start out by assuming PASSWORD is enabled; we will remove it later if not. + login_types = set() + if self._password_localdb_enabled: + login_types.add(LoginType.PASSWORD) - # we keep this as a list despite the O(N^2) implication so that we can - # keep PASSWORD first and avoid confusing clients which pick the first - # type in the list. 
(NB that the spec doesn't require us to do so and - # clients which favour types that they don't understand over those that - # they do are technically broken) - login_types = [] - if self._password_enabled: - login_types.append(LoginType.PASSWORD) for provider in self.password_providers: - if hasattr(provider, "get_supported_login_types"): - for t in provider.get_supported_login_types().keys(): - if t not in login_types: - login_types.append(t) - self._supported_login_types = login_types - # Login types and UI Auth types have a heavy overlap, but are not - # necessarily identical. Login types have SSO (and other login types) - # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET. - ui_auth_types = login_types.copy() - if self._sso_enabled: - ui_auth_types.append(LoginType.SSO) - self._supported_ui_auth_types = ui_auth_types + login_types.update(provider.get_supported_login_types().keys()) + + if not self._password_enabled: + login_types.discard(LoginType.PASSWORD) + + # Some clients just pick the first type in the list. In this case, we want + # them to use PASSWORD (rather than token or whatever), so we want to make sure + # that comes first, where it's present. + self._supported_login_types = [] + if LoginType.PASSWORD in login_types: + self._supported_login_types.append(LoginType.PASSWORD) + login_types.remove(LoginType.PASSWORD) + self._supported_login_types.extend(login_types) # Ratelimiter for failed auth during UIA. Uses same ratelimit config # as per `rc_login.failed_attempts`. @@ -209,10 +226,20 @@ class AuthHandler(BaseHandler): burst_count=self.hs.config.rc_login_failed_attempts.burst_count, ) + # The number of seconds to keep a UI auth session active. + self._ui_auth_session_timeout = hs.config.ui_auth_session_timeout + + # Ratelimitier for failed /login attempts + self._failed_login_attempts_ratelimiter = Ratelimiter( + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_failed_attempts.per_second, + burst_count=self.hs.config.rc_login_failed_attempts.burst_count, + ) + self._clock = self.hs.get_clock() # Expire old UI auth sessions after a period of time. - if hs.config.worker_app is None: + if hs.config.run_background_tasks: self._clock.looping_call( run_as_background_process, 5 * 60 * 1000, @@ -259,7 +286,7 @@ class AuthHandler(BaseHandler): request_body: Dict[str, Any], clientip: str, description: str, - ) -> Tuple[dict, str]: + ) -> Tuple[dict, Optional[str]]: """ Checks that the user is who they claim to be, via a UI auth. @@ -286,7 +313,8 @@ class AuthHandler(BaseHandler): have been given only in a previous call). 'session_id' is the ID of this session, either passed in by the - client or assigned by this call + client or assigned by this call. This is None if UI auth was + skipped (by re-using a previous validation). Raises: InteractiveAuthIncompleteError if the client has not yet completed @@ -300,13 +328,26 @@ class AuthHandler(BaseHandler): """ + if self._ui_auth_session_timeout: + last_validated = await self.store.get_access_token_last_validated( + requester.access_token_id + ) + if self.clock.time_msec() - last_validated < self._ui_auth_session_timeout: + # Return the input parameters, minus the auth key, which matches + # the logic in check_ui_auth. 
+ request_body.pop("auth", None) + return request_body, None + user_id = requester.user.to_string() # Check if we should be ratelimited due to too many previous failed attempts self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False) # build a list of supported flows - flows = [[login_type] for login_type in self._supported_ui_auth_types] + supported_ui_auth_types = await self._get_available_ui_auth_types( + requester.user + ) + flows = [[login_type] for login_type in supported_ui_auth_types] try: result, params, session_id = await self.check_ui_auth( @@ -318,7 +359,7 @@ class AuthHandler(BaseHandler): raise # find the completed login type - for login_type in self._supported_ui_auth_types: + for login_type in supported_ui_auth_types: if login_type not in result: continue @@ -332,8 +373,46 @@ class AuthHandler(BaseHandler): if user_id != requester.user.to_string(): raise AuthError(403, "Invalid auth") + # Note that the access token has been validated. + await self.store.update_access_token_last_validated(requester.access_token_id) + return params, session_id + async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]: + """Get a list of the authentication types this user can use + """ + + ui_auth_types = set() + + # if the HS supports password auth, and the user has a non-null password, we + # support password auth + if self._password_localdb_enabled and self._password_enabled: + lookupres = await self._find_user_id_and_pwd_hash(user.to_string()) + if lookupres: + _, password_hash = lookupres + if password_hash: + ui_auth_types.add(LoginType.PASSWORD) + + # also allow auth from password providers + for provider in self.password_providers: + for t in provider.get_supported_login_types().keys(): + if t == LoginType.PASSWORD and not self._password_enabled: + continue + ui_auth_types.add(t) + + # if sso is enabled, allow the user to log in via SSO iff they have a mapping + # from sso to mxid. + if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled: + if await self.store.get_external_ids_by_user(user.to_string()): + ui_auth_types.add(LoginType.SSO) + + # Our CAS impl does not (yet) correctly register users in user_external_ids, + # so always offer that if it's available. + if self.hs.config.cas.cas_enabled: + ui_auth_types.add(LoginType.SSO) + + return ui_auth_types + def get_enabled_auth_types(self): """Return the enabled user-interactive authentication types @@ -390,13 +469,10 @@ class AuthHandler(BaseHandler): all the stages in any of the permitted flows. """ - authdict = None sid = None # type: Optional[str] - if clientdict and "auth" in clientdict: - authdict = clientdict["auth"] - del clientdict["auth"] - if "session" in authdict: - sid = authdict["session"] + authdict = clientdict.pop("auth", {}) + if "session" in authdict: + sid = authdict["session"] # Convert the URI and method to strings. uri = request.uri.decode("utf-8") @@ -463,9 +539,7 @@ class AuthHandler(BaseHandler): # authentication flow. 
await self.store.set_ui_auth_clientdict(sid, clientdict) - user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ - 0 - ].decode("ascii", "surrogateescape") + user_agent = request.get_user_agent("") await self.store.add_user_agent_ip_to_ui_auth_session( session.session_id, user_agent, clientip @@ -503,6 +577,8 @@ class AuthHandler(BaseHandler): creds = await self.store.get_completed_ui_auth_stages(session.session_id) for f in flows: + # If all the required credentials have been supplied, the user has + # successfully completed the UI auth process! if len(set(f) - set(creds)) == 0: # it's very useful to know what args are stored, but this can # include the password in the case of registering, so only log @@ -623,14 +699,8 @@ class AuthHandler(BaseHandler): res = await checker.check_auth(authdict, clientip=clientip) return res - # build a v1-login-style dict out of the authdict and fall back to the - # v1 code - user_id = authdict.get("user") - - if user_id is None: - raise SynapseError(400, "", Codes.MISSING_PARAM) - - (canonical_id, callback) = await self.validate_login(user_id, authdict) + # fall back to the v1 login flow + canonical_id, _ = await self.validate_login(authdict) return canonical_id def _get_params_recaptcha(self) -> dict: @@ -679,13 +749,18 @@ class AuthHandler(BaseHandler): } async def get_access_token_for_user_id( - self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int] - ): + self, + user_id: str, + device_id: Optional[str], + valid_until_ms: Optional[int], + puppets_user_id: Optional[str] = None, + is_appservice_ghost: bool = False, + ) -> str: """ Creates a new access token for the user with the given user ID. The user is assumed to have been authenticated by some other - machanism (e.g. CAS), and the user_id converted to the canonical case. + mechanism (e.g. CAS), and the user_id converted to the canonical case. The device will be recorded in the table if it is not there already. @@ -696,6 +771,7 @@ class AuthHandler(BaseHandler): we should always have a device ID) valid_until_ms: when the token is valid until. None for no expiry. + is_appservice_ghost: Whether the user is an application ghost user Returns: The access token for the user's session. 
Raises: @@ -706,13 +782,29 @@ class AuthHandler(BaseHandler): fmt_expiry = time.strftime( " until %Y-%m-%d %H:%M:%S", time.localtime(valid_until_ms / 1000.0) ) - logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry) - await self.auth.check_auth_blocking(user_id) + if puppets_user_id: + logger.info( + "Logging in user %s as %s%s", user_id, puppets_user_id, fmt_expiry + ) + else: + logger.info( + "Logging in user %s on device %s%s", user_id, device_id, fmt_expiry + ) + + if ( + not is_appservice_ghost + or self.hs.config.appservice.track_appservice_user_ips + ): + await self.auth.check_auth_blocking(user_id) access_token = self.macaroon_gen.generate_access_token(user_id) await self.store.add_access_token_to_user( - user_id, access_token, device_id, valid_until_ms + user_id=user_id, + token=access_token, + device_id=device_id, + valid_until_ms=valid_until_ms, + puppets_user_id=puppets_user_id, ) # the device *should* have been registered before we got here; however, @@ -789,17 +881,17 @@ class AuthHandler(BaseHandler): return self._supported_login_types async def validate_login( - self, username: str, login_submission: Dict[str, Any] - ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]: + self, login_submission: Dict[str, Any], ratelimit: bool = False, + ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]: """Authenticates the user for the /login API - Also used by the user-interactive auth flow to validate - m.login.password auth types. + Also used by the user-interactive auth flow to validate auth types which don't + have an explicit UIA handler, including m.password.auth. Args: - username: username supplied by the user login_submission: the whole of the login submission (including 'type' and other relevant fields) + ratelimit: whether to apply the failed_login_attempt ratelimiter Returns: A tuple of the canonical user id, and optional callback to be called once the access token and device id are issued @@ -808,38 +900,160 @@ class AuthHandler(BaseHandler): SynapseError if there was a problem with the request LoginError if there was an authentication problem. """ - - if username.startswith("@"): - qualified_user_id = username - else: - qualified_user_id = UserID(username, self.hs.hostname).to_string() - login_type = login_submission.get("type") - known_login_type = False + if not isinstance(login_type, str): + raise SynapseError(400, "Bad parameter: type", Codes.INVALID_PARAM) + + # ideally, we wouldn't be checking the identifier unless we know we have a login + # method which uses it (https://github.com/matrix-org/synapse/issues/8836) + # + # But the auth providers' check_auth interface requires a username, so in + # practice we can only support login methods which we can map to a username + # anyway. 
# special case to check for "password" for the check_password interface # for the auth providers password = login_submission.get("password") - if login_type == LoginType.PASSWORD: if not self._password_enabled: raise SynapseError(400, "Password login has been disabled.") - if not password: - raise SynapseError(400, "Missing parameter: password") + if not isinstance(password, str): + raise SynapseError(400, "Bad parameter: password", Codes.INVALID_PARAM) - for provider in self.password_providers: - if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD: - known_login_type = True - is_valid = await provider.check_password(qualified_user_id, password) - if is_valid: - return qualified_user_id, None + # map old-school login fields into new-school "identifier" fields. + identifier_dict = convert_client_dict_legacy_fields_to_identifier( + login_submission + ) - if not hasattr(provider, "get_supported_login_types") or not hasattr( - provider, "check_auth" - ): - # this password provider doesn't understand custom login types - continue + # convert phone type identifiers to generic threepids + if identifier_dict["type"] == "m.id.phone": + identifier_dict = login_id_phone_to_thirdparty(identifier_dict) + + # convert threepid identifiers to user IDs + if identifier_dict["type"] == "m.id.thirdparty": + address = identifier_dict.get("address") + medium = identifier_dict.get("medium") + + if medium is None or address is None: + raise SynapseError(400, "Invalid thirdparty identifier") + + # For emails, canonicalise the address. + # We store all email addresses canonicalised in the DB. + # (See add_threepid in synapse/handlers/auth.py) + if medium == "email": + try: + address = canonicalise_email(address) + except ValueError as e: + raise SynapseError(400, str(e)) + + # We also apply account rate limiting using the 3PID as a key, as + # otherwise using 3PID bypasses the ratelimiting based on user ID. + if ratelimit: + self._failed_login_attempts_ratelimiter.ratelimit( + (medium, address), update=False + ) + + # Check for login providers that support 3pid login types + if login_type == LoginType.PASSWORD: + # we've already checked that there is a (valid) password field + assert isinstance(password, str) + ( + canonical_user_id, + callback_3pid, + ) = await self.check_password_provider_3pid(medium, address, password) + if canonical_user_id: + # Authentication through password provider and 3pid succeeded + return canonical_user_id, callback_3pid + + # No password providers were able to handle this 3pid + # Check local store + user_id = await self.hs.get_datastore().get_user_id_by_threepid( + medium, address + ) + if not user_id: + logger.warning( + "unknown 3pid identifier medium %s, address %r", medium, address + ) + # We mark that we've failed to log in here, as + # `check_password_provider_3pid` might have returned `None` due + # to an incorrect password, rather than the account not + # existing. + # + # If it returned None but the 3PID was bound then we won't hit + # this code path, which is fine as then the per-user ratelimit + # will kick in below. + if ratelimit: + self._failed_login_attempts_ratelimiter.can_do_action( + (medium, address) + ) + raise LoginError(403, "", errcode=Codes.FORBIDDEN) + + identifier_dict = {"type": "m.id.user", "user": user_id} + + # by this point, the identifier should be an m.id.user: if it's anything + # else, we haven't understood it. 
+ if identifier_dict["type"] != "m.id.user": + raise SynapseError(400, "Unknown login identifier type") + + username = identifier_dict.get("user") + if not username: + raise SynapseError(400, "User identifier is missing 'user' key") + + if username.startswith("@"): + qualified_user_id = username + else: + qualified_user_id = UserID(username, self.hs.hostname).to_string() + + # Check if we've hit the failed ratelimit (but don't update it) + if ratelimit: + self._failed_login_attempts_ratelimiter.ratelimit( + qualified_user_id.lower(), update=False + ) + + try: + return await self._validate_userid_login(username, login_submission) + except LoginError: + # The user has failed to log in, so we need to update the rate + # limiter. Using `can_do_action` avoids us raising a ratelimit + # exception and masking the LoginError. The actual ratelimiting + # should have happened above. + if ratelimit: + self._failed_login_attempts_ratelimiter.can_do_action( + qualified_user_id.lower() + ) + raise + + async def _validate_userid_login( + self, username: str, login_submission: Dict[str, Any], + ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]: + """Helper for validate_login + Handles login, once we've mapped 3pids onto userids + + Args: + username: the username, from the identifier dict + login_submission: the whole of the login submission + (including 'type' and other relevant fields) + Returns: + A tuple of the canonical user id, and optional callback + to be called once the access token and device id are issued + Raises: + StoreError if there was a problem accessing the database + SynapseError if there was a problem with the request + LoginError if there was an authentication problem. + """ + if username.startswith("@"): + qualified_user_id = username + else: + qualified_user_id = UserID(username, self.hs.hostname).to_string() + + login_type = login_submission.get("type") + # we already checked that we have a valid login type + assert isinstance(login_type, str) + + known_login_type = False + + for provider in self.password_providers: supported_login_types = provider.get_supported_login_types() if login_type not in supported_login_types: # this password provider doesn't understand this login type @@ -864,15 +1078,17 @@ class AuthHandler(BaseHandler): result = await provider.check_auth(username, login_type, login_dict) if result: - if isinstance(result, str): - result = (result, None) return result - if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled: + if login_type == LoginType.PASSWORD and self._password_localdb_enabled: known_login_type = True + # we've already checked that there is a (valid) password field + password = login_submission["password"] + assert isinstance(password, str) + canonical_user_id = await self._check_local_password( - qualified_user_id, password # type: ignore + qualified_user_id, password ) if canonical_user_id: @@ -887,7 +1103,7 @@ class AuthHandler(BaseHandler): async def check_password_provider_3pid( self, medium: str, address: str, password: str - ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], None]]]: + ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]: """Check if a password provider is able to validate a thirdparty login Args: @@ -903,19 +1119,9 @@ class AuthHandler(BaseHandler): unsuccessful, `user_id` and `callback` are both `None`. 
""" for provider in self.password_providers: - if hasattr(provider, "check_3pid_auth"): - # This function is able to return a deferred that either - # resolves None, meaning authentication failure, or upon - # success, to a str (which is the user_id) or a tuple of - # (user_id, callback_func), where callback_func should be run - # after we've finished everything else - result = await provider.check_3pid_auth(medium, address, password) - if result: - # Check if the return value is a str or a tuple - if isinstance(result, str): - # If it's a str, set callback function to None - result = (result, None) - return result + result = await provider.check_3pid_auth(medium, address, password) + if result: + return result return None, None @@ -973,21 +1179,16 @@ class AuthHandler(BaseHandler): # see if any of our auth providers want to know about this for provider in self.password_providers: - if hasattr(provider, "on_logged_out"): - # This might return an awaitable, if it does block the log out - # until it completes. - result = provider.on_logged_out( - user_id=str(user_info["user"]), - device_id=user_info["device_id"], - access_token=access_token, - ) - if inspect.isawaitable(result): - await result + await provider.on_logged_out( + user_id=user_info.user_id, + device_id=user_info.device_id, + access_token=access_token, + ) # delete pushers associated with this access token - if user_info["token_id"] is not None: + if user_info.token_id is not None: await self.hs.get_pusherpool().remove_pushers_by_access_token( - str(user_info["user"]), (user_info["token_id"],) + user_info.user_id, (user_info.token_id,) ) async def delete_access_tokens_for_user( @@ -1011,11 +1212,10 @@ class AuthHandler(BaseHandler): # see if any of our auth providers want to know about this for provider in self.password_providers: - if hasattr(provider, "on_logged_out"): - for token, token_id, device_id in tokens_and_devices: - await provider.on_logged_out( - user_id=user_id, device_id=device_id, access_token=token - ) + for token, token_id, device_id in tokens_and_devices: + await provider.on_logged_out( + user_id=user_id, device_id=device_id, access_token=token + ) # delete pushers associated with the access tokens await self.hs.get_pusherpool().remove_pushers_by_access_token( @@ -1073,7 +1273,7 @@ class AuthHandler(BaseHandler): if medium == "email": address = canonicalise_email(address) - identity_handler = self.hs.get_handlers().identity_handler + identity_handler = self.hs.get_identity_handler() result = await identity_handler.try_unbind_threepid( user_id, {"medium": medium, "address": address, "id_server": id_server} ) @@ -1115,20 +1315,22 @@ class AuthHandler(BaseHandler): Whether self.hash(password) == stored_hash. 
""" - def _do_validate_hash(): + def _do_validate_hash(checked_hash: bytes): # Normalise the Unicode in the password pw = unicodedata.normalize("NFKC", password) return bcrypt.checkpw( pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"), - stored_hash, + checked_hash, ) if stored_hash: if not isinstance(stored_hash, bytes): stored_hash = stored_hash.encode("ascii") - return await defer_to_thread(self.hs.get_reactor(), _do_validate_hash) + return await defer_to_thread( + self.hs.get_reactor(), _do_validate_hash, stored_hash + ) else: return False @@ -1152,15 +1354,14 @@ class AuthHandler(BaseHandler): ) async def complete_sso_ui_auth( - self, registered_user_id: str, session_id: str, request: SynapseRequest, + self, registered_user_id: str, session_id: str, request: Request, ): """Having figured out a mxid for this user, complete the HTTP request Args: registered_user_id: The registered user ID to complete SSO login for. + session_id: The ID of the user-interactive auth session. request: The request to complete. - client_redirect_url: The URL to which to redirect the user at the end of the - process. """ # Mark the stage of the authentication as successful. # Save the user who authenticated with SSO, this will be used to ensure @@ -1176,7 +1377,7 @@ class AuthHandler(BaseHandler): async def complete_sso_login( self, registered_user_id: str, - request: SynapseRequest, + request: Request, client_redirect_url: str, extra_attributes: Optional[JsonDict] = None, ): @@ -1204,7 +1405,7 @@ class AuthHandler(BaseHandler): def _complete_sso_login( self, registered_user_id: str, - request: SynapseRequest, + request: Request, client_redirect_url: str, extra_attributes: Optional[JsonDict] = None, ): @@ -1337,3 +1538,127 @@ class MacaroonGenerator: macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) return macaroon + + +class PasswordProvider: + """Wrapper for a password auth provider module + + This class abstracts out all of the backwards-compatibility hacks for + password providers, to provide a consistent interface. + """ + + @classmethod + def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider": + try: + pp = module(config=config, account_handler=module_api) + except Exception as e: + logger.error("Error while initializing %r: %s", module, e) + raise + return cls(pp, module_api) + + def __init__(self, pp, module_api: ModuleApi): + self._pp = pp + self._module_api = module_api + + self._supported_login_types = {} + + # grandfather in check_password support + if hasattr(self._pp, "check_password"): + self._supported_login_types[LoginType.PASSWORD] = ("password",) + + g = getattr(self._pp, "get_supported_login_types", None) + if g: + self._supported_login_types.update(g()) + + def __str__(self): + return str(self._pp) + + def get_supported_login_types(self) -> Mapping[str, Iterable[str]]: + """Get the login types supported by this password provider + + Returns a map from a login type identifier (such as m.login.password) to an + iterable giving the fields which must be provided by the user in the submission + to the /login API. + + This wrapper adds m.login.password to the list if the underlying password + provider supports the check_password() api. 
+ """ + return self._supported_login_types + + async def check_auth( + self, username: str, login_type: str, login_dict: JsonDict + ) -> Optional[Tuple[str, Optional[Callable]]]: + """Check if the user has presented valid login credentials + + This wrapper also calls check_password() if the underlying password provider + supports the check_password() api and the login type is m.login.password. + + Args: + username: user id presented by the client. Either an MXID or an unqualified + username. + + login_type: the login type being attempted - one of the types returned by + get_supported_login_types() + + login_dict: the dictionary of login secrets passed by the client. + + Returns: (user_id, callback) where `user_id` is the fully-qualified mxid of the + user, and `callback` is an optional callback which will be called with the + result from the /login call (including access_token, device_id, etc.) + """ + # first grandfather in a call to check_password + if login_type == LoginType.PASSWORD: + g = getattr(self._pp, "check_password", None) + if g: + qualified_user_id = self._module_api.get_qualified_user_id(username) + is_valid = await self._pp.check_password( + qualified_user_id, login_dict["password"] + ) + if is_valid: + return qualified_user_id, None + + g = getattr(self._pp, "check_auth", None) + if not g: + return None + result = await g(username, login_type, login_dict) + + # Check if the return value is a str or a tuple + if isinstance(result, str): + # If it's a str, set callback function to None + return result, None + + return result + + async def check_3pid_auth( + self, medium: str, address: str, password: str + ) -> Optional[Tuple[str, Optional[Callable]]]: + g = getattr(self._pp, "check_3pid_auth", None) + if not g: + return None + + # This function is able to return a deferred that either + # resolves None, meaning authentication failure, or upon + # success, to a str (which is the user_id) or a tuple of + # (user_id, callback_func), where callback_func should be run + # after we've finished everything else + result = await g(medium, address, password) + + # Check if the return value is a str or a tuple + if isinstance(result, str): + # If it's a str, set callback function to None + return result, None + + return result + + async def on_logged_out( + self, user_id: str, device_id: Optional[str], access_token: str + ) -> None: + g = getattr(self._pp, "on_logged_out", None) + if not g: + return + + # This might return an awaitable, if it does block the log out + # until it completes. + await maybe_awaitable( + g(user_id=user_id, device_id=device_id, access_token=access_token,) + ) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index a4cc4b9a5a..fca210a5a6 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py
@@ -13,30 +13,57 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import urllib -from typing import Dict, Optional, Tuple +import urllib.parse +from typing import TYPE_CHECKING, Dict, Optional from xml.etree import ElementTree as ET +import attr + from twisted.web.client import PartialDownloadError -from synapse.api.errors import Codes, LoginError +from synapse.api.errors import HttpResponseException +from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.site import SynapseRequest from synapse.types import UserID, map_username_to_mxid_localpart +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) +class CasError(Exception): + """Used to catch errors when validating the CAS ticket. + """ + + def __init__(self, error, error_description=None): + self.error = error + self.error_description = error_description + + def __str__(self): + if self.error_description: + return "{}: {}".format(self.error, self.error_description) + return self.error + + +@attr.s(slots=True, frozen=True) +class CasResponse: + username = attr.ib(type=str) + attributes = attr.ib(type=Dict[str, Optional[str]]) + + class CasHandler: """ Utility class for to handle the response from a CAS SSO service. Args: - hs (synapse.server.HomeServer) + hs """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self._hostname = hs.hostname + self._store = hs.get_datastore() self._auth_handler = hs.get_auth_handler() self._registration_handler = hs.get_registration_handler() @@ -47,6 +74,11 @@ class CasHandler: self._http_client = hs.get_proxied_http_client() + # identifier for the external_ids table + self._auth_provider_id = "cas" + + self._sso_handler = hs.get_sso_handler() + def _build_service_param(self, args: Dict[str, str]) -> str: """ Generates a value to use as the "service" parameter when redirecting or @@ -66,14 +98,20 @@ class CasHandler: async def _validate_ticket( self, ticket: str, service_args: Dict[str, str] - ) -> Tuple[str, Optional[str]]: + ) -> CasResponse: """ - Validate a CAS ticket with the server, parse the response, and return the user and display name. + Validate a CAS ticket with the server, and return the parsed the response. Args: ticket: The CAS ticket from the client. service_args: Additional arguments to include in the service URL. Should be the same as those passed to `get_redirect_url`. + + Raises: + CasError: If there's an error parsing the CAS response. + + Returns: + The parsed CAS response. """ uri = self._cas_server_url + "/proxyValidate" args = { @@ -86,66 +124,65 @@ class CasHandler: # Twisted raises this error if the connection is closed, # even if that's being used old-http style to signal end-of-data body = pde.response + except HttpResponseException as e: + description = ( + ( + 'Authorization server responded with a "{status}" error ' + "while exchanging the authorization code." 
+ ).format(status=e.code), + ) + raise CasError("server_error", description) from e - user, attributes = self._parse_cas_response(body) - displayname = attributes.pop(self._cas_displayname_attribute, None) - - for required_attribute, required_value in self._cas_required_attributes.items(): - # If required attribute was not in CAS Response - Forbidden - if required_attribute not in attributes: - raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) - - # Also need to check value - if required_value is not None: - actual_value = attributes[required_attribute] - # If required attribute value does not match expected - Forbidden - if required_value != actual_value: - raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) - - return user, displayname + return self._parse_cas_response(body) - def _parse_cas_response( - self, cas_response_body: bytes - ) -> Tuple[str, Dict[str, Optional[str]]]: + def _parse_cas_response(self, cas_response_body: bytes) -> CasResponse: """ Retrieve the user and other parameters from the CAS response. Args: cas_response_body: The response from the CAS query. + Raises: + CasError: If there's an error parsing the CAS response. + Returns: - A tuple of the user and a mapping of other attributes. + The parsed CAS response. """ + + # Ensure the response is valid. + root = ET.fromstring(cas_response_body) + if not root.tag.endswith("serviceResponse"): + raise CasError( + "missing_service_response", + "root of CAS response is not serviceResponse", + ) + + success = root[0].tag.endswith("authenticationSuccess") + if not success: + raise CasError("unsucessful_response", "Unsuccessful CAS response") + + # Iterate through the nodes and pull out the user and any extra attributes. user = None attributes = {} - try: - root = ET.fromstring(cas_response_body) - if not root.tag.endswith("serviceResponse"): - raise Exception("root of CAS response is not serviceResponse") - success = root[0].tag.endswith("authenticationSuccess") - for child in root[0]: - if child.tag.endswith("user"): - user = child.text - if child.tag.endswith("attributes"): - for attribute in child: - # ElementTree library expands the namespace in - # attribute tags to the full URL of the namespace. - # We don't care about namespace here and it will always - # be encased in curly braces, so we remove them. - tag = attribute.tag - if "}" in tag: - tag = tag.split("}")[1] - attributes[tag] = attribute.text - if user is None: - raise Exception("CAS response does not contain user") - except Exception: - logger.exception("Error parsing CAS response") - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - if not success: - raise LoginError( - 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED - ) - return user, attributes + for child in root[0]: + if child.tag.endswith("user"): + user = child.text + if child.tag.endswith("attributes"): + for attribute in child: + # ElementTree library expands the namespace in + # attribute tags to the full URL of the namespace. + # We don't care about namespace here and it will always + # be encased in curly braces, so we remove them. + tag = attribute.tag + if "}" in tag: + tag = tag.split("}")[1] + attributes[tag] = attribute.text + + # Ensure a user was found. 
+ if user is None: + raise CasError("no_user", "CAS response does not contain user") + + return CasResponse(user, attributes) def get_redirect_url(self, service_args: Dict[str, str]) -> str: """ @@ -198,31 +235,150 @@ class CasHandler: args["redirectUrl"] = client_redirect_url if session: args["session"] = session - username, user_display_name = await self._validate_ticket(ticket, args) - localpart = map_username_to_mxid_localpart(username) - user_id = UserID(localpart, self._hostname).to_string() - registered_user_id = await self._auth_handler.check_user_exists(user_id) + try: + cas_response = await self._validate_ticket(ticket, args) + except CasError as e: + logger.exception("Could not validate ticket") + self._sso_handler.render_error(request, e.error, e.error_description, 401) + return + + await self._handle_cas_response( + request, cas_response, client_redirect_url, session + ) + + async def _handle_cas_response( + self, + request: SynapseRequest, + cas_response: CasResponse, + client_redirect_url: Optional[str], + session: Optional[str], + ) -> None: + """Handle a CAS response to a ticket request. + + Assumes that the response has been validated. Maps the user onto an MXID, + registering them if necessary, and returns a response to the browser. + + Args: + request: the incoming request from the browser. We'll respond to it with an + HTML page or a redirect + + cas_response: The parsed CAS response. + + client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given. + This should be the same as the redirectUrl from the original `/login/sso/redirect` request. + + session: The session parameter from the `/cas/ticket` HTTP request, if given. + This should be the UI Auth session id. + """ + # first check if we're doing a UIA if session: - await self._auth_handler.complete_sso_ui_auth( - registered_user_id, session, request, + return await self._sso_handler.complete_sso_ui_auth_request( + self._auth_provider_id, cas_response.username, session, request, ) - else: - if not registered_user_id: - # Pull out the user-agent and IP from the request. - user_agent = request.requestHeaders.getRawHeaders( - b"User-Agent", default=[b""] - )[0].decode("ascii", "surrogateescape") - ip_address = self.hs.get_ip_from_request(request) - - registered_user_id = await self._registration_handler.register_user( - localpart=localpart, - default_display_name=user_display_name, - user_agent_ips=(user_agent, ip_address), + # otherwise, we're handling a login request. + + # Ensure that the attributes of the logged in user meet the required + # attributes. + for required_attribute, required_value in self._cas_required_attributes.items(): + # If required attribute was not in CAS Response - Forbidden + if required_attribute not in cas_response.attributes: + self._sso_handler.render_error( + request, + "unauthorised", + "You are not authorised to log in here.", + 401, ) + return - await self._auth_handler.complete_sso_login( - registered_user_id, request, client_redirect_url + # Also need to check value + if required_value is not None: + actual_value = cas_response.attributes[required_attribute] + # If required attribute value does not match expected - Forbidden + if required_value != actual_value: + self._sso_handler.render_error( + request, + "unauthorised", + "You are not authorised to log in here.", + 401, + ) + return + + # Call the mapper to register/login the user + + # If this not a UI auth request than there must be a redirect URL. 
+ assert client_redirect_url is not None + + try: + await self._complete_cas_login(cas_response, request, client_redirect_url) + except MappingException as e: + logger.exception("Could not map user") + self._sso_handler.render_error(request, "mapping_error", str(e)) + + async def _complete_cas_login( + self, + cas_response: CasResponse, + request: SynapseRequest, + client_redirect_url: str, + ) -> None: + """ + Given a CAS response, complete the login flow + + Retrieves the remote user ID, registers the user if necessary, and serves + a redirect back to the client with a login-token. + + Args: + cas_response: The parsed CAS response. + request: The request to respond to + client_redirect_url: The redirect URL passed in by the client. + + Raises: + MappingException if there was a problem mapping the response to a user. + RedirectException: some mapping providers may raise this if they need + to redirect to an interstitial page. + """ + # Note that CAS does not support a mapping provider, so the logic is hard-coded. + localpart = map_username_to_mxid_localpart(cas_response.username) + + async def cas_response_to_user_attributes(failures: int) -> UserAttributes: + """ + Map from CAS attributes to user attributes. + """ + # Due to the grandfathering logic matching any previously registered + # mxids it isn't expected for there to be any failures. + if failures: + raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs") + + display_name = cas_response.attributes.get( + self._cas_displayname_attribute, None ) + + return UserAttributes(localpart=localpart, display_name=display_name) + + async def grandfather_existing_users() -> Optional[str]: + # Since CAS did not always use the user_external_ids table, always + # to attempt to map to existing users. + user_id = UserID(localpart, self._hostname).to_string() + + logger.debug( + "Looking for existing account based on mapped %s", user_id, + ) + + users = await self._store.get_users_by_id_case_insensitive(user_id) + if users: + registered_user_id = list(users.keys())[0] + logger.info("Grandfathering mapping to %s", registered_user_id) + return registered_user_id + + return None + + await self._sso_handler.complete_sso_login_request( + self._auth_provider_id, + cas_response.username, + request, + client_redirect_url, + cas_response_to_user_attributes, + grandfather_existing_users, + ) diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
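As a standalone illustration of the parsing strategy used by _parse_cas_response above (namespace-agnostic tag matching, attributes collected into a plain dict), the sketch below parses a typical CAS 2.0 success body. The sample XML and the parse_cas_success helper are invented for this example and are not taken from the changeset.

from typing import Dict, Optional, Tuple
from xml.etree import ElementTree as ET

SAMPLE_BODY = b"""<?xml version="1.0"?>
<cas:serviceResponse xmlns:cas="http://www.yale.edu/tp/cas">
  <cas:authenticationSuccess>
    <cas:user>jsmith</cas:user>
    <cas:attributes>
      <cas:displayName>Jane Smith</cas:displayName>
    </cas:attributes>
  </cas:authenticationSuccess>
</cas:serviceResponse>"""


def parse_cas_success(body: bytes) -> Tuple[str, Dict[str, Optional[str]]]:
    root = ET.fromstring(body)
    if not root.tag.endswith("serviceResponse"):
        raise ValueError("root of CAS response is not serviceResponse")
    if not root[0].tag.endswith("authenticationSuccess"):
        raise ValueError("unsuccessful CAS response")

    user = None
    attributes = {}  # type: Dict[str, Optional[str]]
    for child in root[0]:
        if child.tag.endswith("user"):
            user = child.text
        if child.tag.endswith("attributes"):
            for attribute in child:
                # ElementTree expands namespaces into "{uri}tag"; strip them.
                tag = attribute.tag
                if "}" in tag:
                    tag = tag.split("}")[1]
                attributes[tag] = attribute.text
    if user is None:
        raise ValueError("CAS response does not contain user")
    return user, attributes


print(parse_cas_success(SAMPLE_BODY))  # ('jsmith', {'displayName': 'Jane Smith'})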
index 0635ad5708..e808142365 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py
@@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Optional +from typing import TYPE_CHECKING, Optional from synapse.api.errors import SynapseError from synapse.metrics.background_process_metrics import run_as_background_process @@ -22,27 +22,31 @@ from synapse.types import UserID, create_requester from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class DeactivateAccountHandler(BaseHandler): """Handler which deals with deactivating user accounts.""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() self._room_member_handler = hs.get_room_member_handler() - self._identity_handler = hs.get_handlers().identity_handler + self._identity_handler = hs.get_identity_handler() self.user_directory_handler = hs.get_user_directory_handler() + self._server_name = hs.hostname # Flag that indicates whether the process to part users from rooms is running self._user_parter_running = False # Start the user parter loop so it can resume parting users from rooms where # it left off (if it has work left to do). - if hs.config.worker_app is None: + if hs.config.run_background_tasks: hs.get_reactor().callWhenRunning(self._start_user_parting) self._account_validity_enabled = hs.config.account_validity.enabled @@ -137,7 +141,7 @@ class DeactivateAccountHandler(BaseHandler): return identity_server_supports_unbinding - async def _reject_pending_invites_for_user(self, user_id: str): + async def _reject_pending_invites_for_user(self, user_id: str) -> None: """Reject pending invites addressed to a given user ID. Args: @@ -149,7 +153,7 @@ class DeactivateAccountHandler(BaseHandler): for room in pending_invites: try: await self._room_member_handler.update_membership( - create_requester(user), + create_requester(user, authenticated_entity=self._server_name), user, room.room_id, "leave", @@ -205,7 +209,7 @@ class DeactivateAccountHandler(BaseHandler): logger.info("User parter parting %r from %r", user_id, room_id) try: await self._room_member_handler.update_membership( - create_requester(user), + create_requester(user, authenticated_entity=self._server_name), user, room_id, "leave", diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
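The create_requester change above attributes the room-leave events to the deactivated user while recording the homeserver itself as the authenticated entity. A simplified stand-in illustrating that distinction (Requester and create_requester here are toy versions, not the synapse.types implementations):

from dataclasses import dataclass
from typing import Optional


@dataclass
class Requester:
    user: str
    authenticated_entity: str


def create_requester(user: str, authenticated_entity: Optional[str] = None) -> Requester:
    # Default to the user acting on their own behalf.
    return Requester(user, authenticated_entity or user)


print(create_requester("@alice:example.com"))
print(create_requester("@alice:example.com", authenticated_entity="example.com"))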
index b9d9098104..debb1b4f29 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py
@@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd # Copyright 2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019,2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple from synapse.api import errors from synapse.api.constants import EventTypes @@ -29,7 +29,10 @@ from synapse.api.errors import ( from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import ( + Collection, + JsonDict, StreamToken, + UserID, get_domain_from_id, get_verify_key_from_cross_signing_key, ) @@ -41,13 +44,16 @@ from synapse.util.retryutils import NotRetryingDestination from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) MAX_DEVICE_DISPLAY_NAME_LEN = 100 class DeviceWorkerHandler(BaseHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs @@ -105,7 +111,9 @@ class DeviceWorkerHandler(BaseHandler): @trace @measure_func("device.get_user_ids_changed") - async def get_user_ids_changed(self, user_id: str, from_token: StreamToken): + async def get_user_ids_changed( + self, user_id: str, from_token: StreamToken + ) -> JsonDict: """Get list of users that have had the devices updated, or have newly joined a room, that `user_id` may be interested in. """ @@ -221,8 +229,8 @@ class DeviceWorkerHandler(BaseHandler): possibly_joined = possibly_changed & users_who_share_room possibly_left = (possibly_changed | possibly_left) - users_who_share_room else: - possibly_joined = [] - possibly_left = [] + possibly_joined = set() + possibly_left = set() result = {"changed": list(possibly_joined), "left": list(possibly_left)} @@ -230,7 +238,7 @@ class DeviceWorkerHandler(BaseHandler): return result - async def on_federation_query_user_devices(self, user_id): + async def on_federation_query_user_devices(self, user_id: str) -> JsonDict: stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query( user_id ) @@ -249,7 +257,7 @@ class DeviceWorkerHandler(BaseHandler): class DeviceHandler(DeviceWorkerHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.federation_sender = hs.get_federation_sender() @@ -264,7 +272,7 @@ class DeviceHandler(DeviceWorkerHandler): hs.get_distributor().observe("user_left_room", self.user_left_room) - def _check_device_name_length(self, name: str): + def _check_device_name_length(self, name: Optional[str]): """ Checks whether a device name is longer than the maximum allowed length. @@ -283,8 +291,11 @@ class DeviceHandler(DeviceWorkerHandler): ) async def check_device_registered( - self, user_id, device_id, initial_device_display_name=None - ): + self, + user_id: str, + device_id: Optional[str], + initial_device_display_name: Optional[str] = None, + ) -> str: """ If the given device has not been registered, register it with the supplied display name. @@ -292,12 +303,11 @@ class DeviceHandler(DeviceWorkerHandler): If no device_id is supplied, we make one up. 
Args: - user_id (str): @user:id - device_id (str | None): device id supplied by client - initial_device_display_name (str | None): device display name from - client + user_id: @user:id + device_id: device id supplied by client + initial_device_display_name: device display name from client Returns: - str: device id (generated if none was supplied) + device id (generated if none was supplied) """ self._check_device_name_length(initial_device_display_name) @@ -316,15 +326,15 @@ class DeviceHandler(DeviceWorkerHandler): # times in case of a clash. attempts = 0 while attempts < 5: - device_id = stringutils.random_string(10).upper() + new_device_id = stringutils.random_string(10).upper() new_device = await self.store.store_device( user_id=user_id, - device_id=device_id, + device_id=new_device_id, initial_device_display_name=initial_device_display_name, ) if new_device: - await self.notify_device_update(user_id, [device_id]) - return device_id + await self.notify_device_update(user_id, [new_device_id]) + return new_device_id attempts += 1 raise errors.StoreError(500, "Couldn't generate a device ID.") @@ -433,7 +443,9 @@ class DeviceHandler(DeviceWorkerHandler): @trace @measure_func("notify_device_update") - async def notify_device_update(self, user_id, device_ids): + async def notify_device_update( + self, user_id: str, device_ids: Collection[str] + ) -> None: """Notify that a user's device(s) has changed. Pokes the notifier, and remote servers if the user is local. """ @@ -445,7 +457,7 @@ class DeviceHandler(DeviceWorkerHandler): user_id ) - hosts = set() + hosts = set() # type: Set[str] if self.hs.is_mine_id(user_id): hosts.update(get_domain_from_id(u) for u in users_who_share_room) hosts.discard(self.server_name) @@ -497,7 +509,7 @@ class DeviceHandler(DeviceWorkerHandler): self.notifier.on_new_event("device_list_key", position, users=[from_user_id]) - async def user_left_room(self, user, room_id): + async def user_left_room(self, user: UserID, room_id: str) -> None: user_id = user.to_string() room_ids = await self.store.get_rooms_for_user(user_id) if not room_ids: @@ -505,8 +517,89 @@ class DeviceHandler(DeviceWorkerHandler): # receive device updates. Mark this in DB. await self.store.mark_remote_user_device_list_as_unsubscribed(user_id) + async def store_dehydrated_device( + self, + user_id: str, + device_data: JsonDict, + initial_device_display_name: Optional[str] = None, + ) -> str: + """Store a dehydrated device for a user. If the user had a previous + dehydrated device, it is removed. + + Args: + user_id: the user that we are storing the device for + device_data: the dehydrated device information + initial_device_display_name: The display name to use for the device + Returns: + device id of the dehydrated device + """ + device_id = await self.check_device_registered( + user_id, None, initial_device_display_name, + ) + old_device_id = await self.store.store_dehydrated_device( + user_id, device_id, device_data + ) + if old_device_id is not None: + await self.delete_device(user_id, old_device_id) + return device_id + + async def get_dehydrated_device( + self, user_id: str + ) -> Optional[Tuple[str, JsonDict]]: + """Retrieve the information for a dehydrated device. 
+ + Args: + user_id: the user whose dehydrated device we are looking for + Returns: + a tuple whose first item is the device ID, and the second item is + the dehydrated device information + """ + return await self.store.get_dehydrated_device(user_id) + + async def rehydrate_device( + self, user_id: str, access_token: str, device_id: str + ) -> dict: + """Process a rehydration request from the user. + + Args: + user_id: the user who is rehydrating the device + access_token: the access token used for the request + device_id: the ID of the device that will be rehydrated + Returns: + a dict containing {"success": True} + """ + success = await self.store.remove_dehydrated_device(user_id, device_id) + + if not success: + raise errors.NotFoundError() + + # If the dehydrated device was successfully deleted (the device ID + # matched the stored dehydrated device), then modify the access + # token to use the dehydrated device's ID and copy the old device + # display name to the dehydrated device, and destroy the old device + # ID + old_device_id = await self.store.set_device_for_access_token( + access_token, device_id + ) + old_device = await self.store.get_device(user_id, old_device_id) + await self.store.update_device(user_id, device_id, old_device["display_name"]) + # can't call self.delete_device because that will clobber the + # access token so call the storage layer directly + await self.store.delete_device(user_id, old_device_id) + await self.store.delete_e2e_keys_by_device( + user_id=user_id, device_id=old_device_id + ) -def _update_device_from_client_ips(device, client_ips): + # tell everyone that the old device is gone and that the dehydrated + # device has a new display name + await self.notify_device_update(user_id, [old_device_id, device_id]) + + return {"success": True} + + +def _update_device_from_client_ips( + device: Dict[str, Any], client_ips: Dict[Tuple[str, str], Dict[str, Any]] +) -> None: ip = client_ips.get((device["user_id"], device["device_id"]), {}) device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")}) @@ -514,7 +607,7 @@ def _update_device_from_client_ips(device, client_ips): class DeviceListUpdater: "Handles incoming device list updates from federation and updates the DB" - def __init__(self, hs, device_handler): + def __init__(self, hs: "HomeServer", device_handler: DeviceHandler): self.store = hs.get_datastore() self.federation = hs.get_federation_client() self.clock = hs.get_clock() @@ -523,7 +616,9 @@ class DeviceListUpdater: self._remote_edu_linearizer = Linearizer(name="remote_device_list") # user_id -> list of updates waiting to be handled. - self._pending_updates = {} + self._pending_updates = ( + {} + ) # type: Dict[str, List[Tuple[str, str, Iterable[str], JsonDict]]] # Recently seen stream ids. We don't bother keeping these in the DB, # but they're useful to have them about to reduce the number of spurious @@ -546,7 +641,9 @@ class DeviceListUpdater: ) @trace - async def incoming_device_list_update(self, origin, edu_content): + async def incoming_device_list_update( + self, origin: str, edu_content: JsonDict + ) -> None: """Called on incoming device list update from federation. Responsible for parsing the EDU and adding to pending updates list. """ @@ -607,7 +704,7 @@ class DeviceListUpdater: await self._handle_device_updates(user_id) @measure_func("_incoming_device_list_update") - async def _handle_device_updates(self, user_id): + async def _handle_device_updates(self, user_id: str) -> None: "Actually handle pending updates." 
with (await self._remote_edu_linearizer.queue(user_id)): @@ -655,7 +752,9 @@ class DeviceListUpdater: stream_id for _, stream_id, _, _ in pending_updates ) - async def _need_to_do_resync(self, user_id, updates): + async def _need_to_do_resync( + self, user_id: str, updates: Iterable[Tuple[str, str, Iterable[str], JsonDict]] + ) -> bool: """Given a list of updates for a user figure out if we need to do a full resync, or whether we have enough data that we can just apply the delta. """ @@ -686,7 +785,7 @@ class DeviceListUpdater: return False @trace - async def _maybe_retry_device_resync(self): + async def _maybe_retry_device_resync(self) -> None: """Retry to resync device lists that are out of sync, except if another retry is in progress. """ @@ -729,7 +828,7 @@ class DeviceListUpdater: async def user_device_resync( self, user_id: str, mark_failed_as_stale: bool = True - ) -> Optional[dict]: + ) -> Optional[JsonDict]: """Fetches all devices for a user and updates the device cache with them. Args: @@ -753,7 +852,7 @@ class DeviceListUpdater: # it later. await self.store.mark_remote_user_device_cache_as_stale(user_id) - return + return None except (RequestSendFailed, HttpResponseException) as e: logger.warning( "Failed to handle device list update for %s: %s", user_id, e, @@ -770,12 +869,12 @@ class DeviceListUpdater: # next time we get a device list update for this user_id. # This makes it more likely that the device lists will # eventually become consistent. - return + return None except FederationDeniedError as e: set_tag("error", True) log_kv({"reason": "FederationDeniedError"}) logger.info(e) - return + return None except Exception as e: set_tag("error", True) log_kv( @@ -788,7 +887,7 @@ class DeviceListUpdater: # it later. await self.store.mark_remote_user_device_cache_as_stale(user_id) - return + return None log_kv({"result": result}) stream_id = result["stream_id"] devices = result["devices"] @@ -849,7 +948,7 @@ class DeviceListUpdater: user_id: str, master_key: Optional[Dict[str, Any]], self_signing_key: Optional[Dict[str, Any]], - ) -> list: + ) -> List[str]: """Process the given new master and self-signing key for the given remote user. Args: diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
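The retry loop in check_device_registered earlier in this file's diff can be summarised with the sketch below; allocate_device_id and try_store are illustrative stand-ins for the handler method and store.store_device respectively.

import asyncio
import random
import string
from typing import Awaitable, Callable


async def allocate_device_id(
    try_store: Callable[[str], Awaitable[bool]], attempts: int = 5
) -> str:
    # Generate a short random ID and retry a bounded number of times on clash.
    for _ in range(attempts):
        candidate = "".join(random.choices(string.ascii_uppercase, k=10))
        if await try_store(candidate):
            return candidate
    raise RuntimeError("Couldn't generate a device ID.")


async def demo() -> None:
    existing = {"ABCDEFGHIJ"}

    async def try_store(device_id: str) -> bool:
        # False signals that the ID already exists, mirroring store_device().
        if device_id in existing:
            return False
        existing.add(device_id)
        return True

    print(await allocate_device_id(try_store))


asyncio.run(demo())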
index 64ef7f63ab..9cac5a8463 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py
@@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Any, Dict +from typing import TYPE_CHECKING, Any, Dict from synapse.api.errors import SynapseError from synapse.logging.context import run_in_background @@ -24,18 +24,22 @@ from synapse.logging.opentracing import ( set_tag, start_active_span, ) -from synapse.types import UserID, get_domain_from_id +from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util import json_encoder from synapse.util.stringutils import random_string +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + + logger = logging.getLogger(__name__) class DeviceMessageHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): """ Args: - hs (synapse.server.HomeServer): server + hs: server """ self.store = hs.get_datastore() self.notifier = hs.get_notifier() @@ -48,7 +52,7 @@ class DeviceMessageHandler: self._device_list_updater = hs.get_device_handler().device_list_updater - async def on_direct_to_device_edu(self, origin, content): + async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: local_messages = {} sender_user_id = content["sender"] if origin != get_domain_from_id(sender_user_id): @@ -95,7 +99,7 @@ class DeviceMessageHandler: message_type: str, sender_user_id: str, by_device: Dict[str, Dict[str, Any]], - ): + ) -> None: """Checks inbound device messages for unknown remote devices, and if found marks the remote cache for the user as stale. """ @@ -138,11 +142,16 @@ class DeviceMessageHandler: self._device_list_updater.user_device_resync, sender_user_id ) - async def send_device_message(self, sender_user_id, message_type, messages): + async def send_device_message( + self, + sender_user_id: str, + message_type: str, + messages: Dict[str, Dict[str, JsonDict]], + ) -> None: set_tag("number_of_messages", len(messages)) set_tag("sender", sender_user_id) local_messages = {} - remote_messages = {} + remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]] for user_id, by_device in messages.items(): # we use UserID.from_string to catch invalid user ids if self.is_mine(UserID.from_string(user_id)): diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
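The newly typed send_device_message above groups outgoing to-device messages by destination. A self-contained sketch of that split (split_by_destination is an invented helper; the real code uses UserID.from_string, is_mine and get_domain_from_id):

from typing import Dict, Tuple

JsonDict = Dict[str, object]


def split_by_destination(
    messages: Dict[str, Dict[str, JsonDict]], local_domain: str
) -> Tuple[Dict[str, Dict[str, JsonDict]], Dict[str, Dict[str, Dict[str, JsonDict]]]]:
    local = {}   # messages for users on this homeserver
    remote = {}  # destination server -> user -> device -> message
    for user_id, by_device in messages.items():
        domain = user_id.split(":", 1)[1]
        if domain == local_domain:
            local[user_id] = by_device
        else:
            remote.setdefault(domain, {})[user_id] = by_device
    return local, remote


local, remote = split_by_destination(
    {
        "@a:example.com": {"DEVICE1": {"body": "hi"}},
        "@b:other.org": {"*": {"body": "hi"}},
    },
    "example.com",
)
print(local)   # {'@a:example.com': {'DEVICE1': {'body': 'hi'}}}
print(remote)  # {'other.org': {'@b:other.org': {'*': {'body': 'hi'}}}}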
index 62aa9a2da8..abcf86352d 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py
@@ -46,6 +46,7 @@ class DirectoryHandler(BaseHandler): self.config = hs.config self.enable_room_list_search = hs.config.enable_room_list_search self.require_membership = hs.config.require_membership_for_aliases + self.third_party_event_rules = hs.get_third_party_event_rules() self.federation = hs.get_federation_client() hs.get_federation_registry().register_query_handler( @@ -132,7 +133,9 @@ class DirectoryHandler(BaseHandler): 403, "You must be in the room to create an alias for it" ) - if not self.spam_checker.user_may_create_room_alias(user_id, room_alias): + if not await self.spam_checker.user_may_create_room_alias( + user_id, room_alias + ): raise AuthError(403, "This user is not permitted to create this alias") if not self.config.is_alias_creation_allowed( @@ -383,7 +386,7 @@ class DirectoryHandler(BaseHandler): """ creator = await self.store.get_room_alias_creator(alias.to_string()) - if creator is not None and creator == user_id: + if creator == user_id: return True # Resolve the alias to the corresponding room. @@ -408,7 +411,7 @@ class DirectoryHandler(BaseHandler): """ user_id = requester.user.to_string() - if not self.spam_checker.user_may_publish_room(user_id, room_id): + if not await self.spam_checker.user_may_publish_room(user_id, room_id): raise AuthError( 403, "This user is not permitted to publish rooms to the room list" ) @@ -454,6 +457,15 @@ class DirectoryHandler(BaseHandler): # per alias creation rule? raise SynapseError(403, "Not allowed to publish room") + # Check if publishing is blocked by a third party module + allowed_by_third_party_rules = await ( + self.third_party_event_rules.check_visibility_can_be_modified( + room_id, visibility + ) + ) + if not allowed_by_third_party_rules: + raise SynapseError(403, "Not allowed to publish room") + await self.store.set_room_is_public(room_id, making_public) async def edit_published_appservice_room_list( diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
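The directory handler now consults check_visibility_can_be_modified before publishing a room, as shown above. A hypothetical module implementing that gate might look like the sketch below; only the wrapper-level call (room_id, visibility) appears in the diff, so the module-side signature here is an assumption.

class ExamplePublishRules:
    """Hypothetical third-party rules module that restricts directory publishing."""

    def __init__(self, config: dict, module_api=None):
        self._allow_public = bool(config.get("allow_publishing", False))

    async def check_visibility_can_be_modified(
        self, room_id: str, new_visibility: str
    ) -> bool:
        # Returning False makes the handler raise 403 "Not allowed to publish room".
        if new_visibility == "public":
            return self._allow_public
        return True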
index dd40fd1299..929752150d 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py
@@ -129,6 +129,11 @@ class E2eKeysHandler: if user_id in local_query: results[user_id] = keys + # Get cached cross-signing keys + cross_signing_keys = await self.get_cross_signing_keys_from_cache( + device_keys_query, from_user_id + ) + # Now attempt to get any remote devices from our local cache. remote_queries_not_in_cache = {} if remote_queries: @@ -155,16 +160,28 @@ class E2eKeysHandler: unsigned["device_display_name"] = device_display_name user_devices[device_id] = result + # check for missing cross-signing keys. + for user_id in remote_queries.keys(): + cached_cross_master = user_id in cross_signing_keys["master_keys"] + cached_cross_selfsigning = ( + user_id in cross_signing_keys["self_signing_keys"] + ) + + # check if we are missing only one of cross-signing master or + # self-signing key, but the other one is cached. + # as we need both, this will issue a federation request. + # if we don't have any of the keys, either the user doesn't have + # cross-signing set up, or the cached device list + # is not (yet) updated. + if cached_cross_master ^ cached_cross_selfsigning: + user_ids_not_in_cache.add(user_id) + + # add those users to the list to fetch over federation. for user_id in user_ids_not_in_cache: domain = get_domain_from_id(user_id) r = remote_queries_not_in_cache.setdefault(domain, {}) r[user_id] = remote_queries[user_id] - # Get cached cross-signing keys - cross_signing_keys = await self.get_cross_signing_keys_from_cache( - device_keys_query, from_user_id - ) - # Now fetch any devices that we don't have in our cache @trace async def do_remote_query(destination): @@ -496,6 +513,22 @@ class E2eKeysHandler: log_kv( {"message": "Did not update one_time_keys", "reason": "no keys given"} ) + fallback_keys = keys.get("org.matrix.msc2732.fallback_keys", None) + if fallback_keys and isinstance(fallback_keys, dict): + log_kv( + { + "message": "Updating fallback_keys for device.", + "user_id": user_id, + "device_id": device_id, + } + ) + await self.store.set_e2e_fallback_keys(user_id, device_id, fallback_keys) + elif fallback_keys: + log_kv({"message": "Did not update fallback_keys", "reason": "not a dict"}) + else: + log_kv( + {"message": "Did not update fallback_keys", "reason": "no keys given"} + ) # the device should have been registered already, but it may have been # deleted due to a race with a DELETE request. Or we may be using an diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
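The cross-signing cache check added above only forces a federation re-fetch when exactly one of the two keys is cached, i.e. an XOR: if neither is cached the user may simply not have cross-signing set up, and if both are cached nothing is missing. A tiny illustration:

def needs_federation_refetch(has_master: bool, has_self_signing: bool) -> bool:
    # Mirrors `cached_cross_master ^ cached_cross_selfsigning` above.
    return has_master ^ has_self_signing


for master in (False, True):
    for self_signing in (False, True):
        print(master, self_signing, "->", needs_federation_refetch(master, self_signing))
# Only (True, False) and (False, True) trigger a re-fetch over federation.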
index 1a8144405a..fd8de8696d 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py
@@ -55,6 +55,7 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.handlers._base import BaseHandler +from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( make_deferred_yieldable, nested_logging_context, @@ -67,7 +68,7 @@ from synapse.replication.http.devices import ReplicationUserDevicesResyncRestSer from synapse.replication.http.federation import ( ReplicationCleanRoomRestServlet, ReplicationFederationSendEventsRestServlet, - ReplicationStoreRoomOnInviteRestServlet, + ReplicationStoreRoomOnOutlierMembershipRestServlet, ) from synapse.state import StateResolutionStore from synapse.storage.databases.main.events_worker import EventRedactBehaviour @@ -112,7 +113,7 @@ class FederationHandler(BaseHandler): """Handles events that originated from federation. Responsible for: a) handling received Pdus before handing them on as Events to the rest - of the homeserver (including auth and state conflict resoultion) + of the homeserver (including auth and state conflict resolutions) b) converting events that were produced by local clients that may need to be sent to remote homeservers. c) doing the necessary dances to invite remote users and join remote @@ -139,7 +140,7 @@ class FederationHandler(BaseHandler): self._message_handler = hs.get_message_handler() self._server_notices_mxid = hs.config.server_notices_mxid self.config = hs.config - self.http_client = hs.get_simple_http_client() + self.http_client = hs.get_proxied_blacklisted_http_client() self._instance_name = hs.get_instance_name() self._replication = hs.get_replication_data_handler() @@ -152,12 +153,14 @@ class FederationHandler(BaseHandler): self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client( hs ) - self._maybe_store_room_on_invite = ReplicationStoreRoomOnInviteRestServlet.make_client( + self._maybe_store_room_on_outlier_membership = ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client( hs ) else: self._device_list_updater = hs.get_device_handler().device_list_updater - self._maybe_store_room_on_invite = self.store.maybe_store_room_on_invite + self._maybe_store_room_on_outlier_membership = ( + self.store.maybe_store_room_on_outlier_membership + ) # When joining a room we need to queue any events for that room up. # For each room, a list of (pdu, origin) tuples. @@ -477,7 +480,7 @@ class FederationHandler(BaseHandler): # ---- # # Update richvdh 2018/09/18: There are a number of problems with timing this - # request out agressively on the client side: + # request out aggressively on the client side: # # - it plays badly with the server-side rate-limiter, which starts tarpitting you # if you send too many requests at once, so you end up with the server carefully @@ -495,13 +498,13 @@ class FederationHandler(BaseHandler): # we'll end up back here for the *next* PDU in the list, which exacerbates the # problem. # - # - the agressive 10s timeout was introduced to deal with incoming federation + # - the aggressive 10s timeout was introduced to deal with incoming federation # requests taking 8 hours to process. It's not entirely clear why that was going # on; certainly there were other issues causing traffic storms which are now # resolved, and I think in any case we may be more sensible about our locking # now. We're *certainly* more sensible about our logging. # - # All that said: Let's try increasing the timout to 60s and see what happens. 
+ # All that said: Let's try increasing the timeout to 60s and see what happens. try: missing_events = await self.federation_client.get_missing_events( @@ -1120,7 +1123,7 @@ class FederationHandler(BaseHandler): logger.info(str(e)) continue except RequestSendFailed as e: - logger.info("Falied to get backfill from %s because %s", dom, e) + logger.info("Failed to get backfill from %s because %s", dom, e) continue except FederationDeniedError as e: logger.info(e) @@ -1507,18 +1510,9 @@ class FederationHandler(BaseHandler): event, context = await self.event_creation_handler.create_new_client_event( builder=builder ) - except AuthError as e: + except SynapseError as e: logger.warning("Failed to create join to %s because %s", room_id, e) - raise e - - event_allowed = await self.third_party_event_rules.check_event_allowed( - event, context - ) - if not event_allowed: - logger.info("Creation of join %s forbidden by third-party rules", event) - raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN - ) + raise # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_join_request` @@ -1554,7 +1548,7 @@ class FederationHandler(BaseHandler): # # The reasons we have the destination server rather than the origin # server send it are slightly mysterious: the origin server should have - # all the neccessary state once it gets the response to the send_join, + # all the necessary state once it gets the response to the send_join, # so it could send the event itself if it wanted to. It may be that # doing it this way reduces failure modes, or avoids certain attacks # where a new server selectively tells a subset of the federation that @@ -1567,15 +1561,6 @@ class FederationHandler(BaseHandler): context = await self._handle_new_event(origin, event) - event_allowed = await self.third_party_event_rules.check_event_allowed( - event, context - ) - if not event_allowed: - logger.info("Sending of join %s forbidden by third-party rules", event) - raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN - ) - logger.debug( "on_send_join_request: After _handle_new_event: %s, sigs: %s", event.event_id, @@ -1608,7 +1593,7 @@ class FederationHandler(BaseHandler): if self.hs.config.block_non_admin_invites: raise SynapseError(403, "This server does not accept room invites") - if not self.spam_checker.user_may_invite( + if not await self.spam_checker.user_may_invite( event.sender, event.state_key, event.room_id ): raise SynapseError( @@ -1635,7 +1620,7 @@ class FederationHandler(BaseHandler): # keep a record of the room version, if we don't yet know it. # (this may get overwritten if we later get a different room version in a # join dance). - await self._maybe_store_room_on_invite( + await self._maybe_store_room_on_outlier_membership( room_id=event.room_id, room_version=room_version ) @@ -1667,7 +1652,7 @@ class FederationHandler(BaseHandler): event.internal_metadata.outlier = True event.internal_metadata.out_of_band_membership = True - # Try the host that we succesfully called /make_leave/ on first for + # Try the host that we successfully called /make_leave/ on first for # the /send_leave/ request. 
host_list = list(target_hosts) try: @@ -1748,15 +1733,6 @@ class FederationHandler(BaseHandler): builder=builder ) - event_allowed = await self.third_party_event_rules.check_event_allowed( - event, context - ) - if not event_allowed: - logger.warning("Creation of leave %s forbidden by third-party rules", event) - raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN - ) - try: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_leave_request` @@ -1789,16 +1765,7 @@ class FederationHandler(BaseHandler): event.internal_metadata.outlier = False - context = await self._handle_new_event(origin, event) - - event_allowed = await self.third_party_event_rules.check_event_allowed( - event, context - ) - if not event_allowed: - logger.info("Sending of leave %s forbidden by third-party rules", event) - raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN - ) + await self._handle_new_event(origin, event) logger.debug( "on_send_leave_request: After _handle_new_event: %s, sigs: %s", @@ -2694,18 +2661,6 @@ class FederationHandler(BaseHandler): builder=builder ) - event_allowed = await self.third_party_event_rules.check_event_allowed( - event, context - ) - if not event_allowed: - logger.info( - "Creation of threepid invite %s forbidden by third-party rules", - event, - ) - raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN - ) - event, context = await self.add_display_name_to_third_party_invite( room_version, event_dict, event, context ) @@ -2734,7 +2689,7 @@ class FederationHandler(BaseHandler): ) async def on_exchange_third_party_invite_request( - self, room_id: str, event_dict: JsonDict + self, event_dict: JsonDict ) -> None: """Handle an exchange_third_party_invite request from a remote server @@ -2742,12 +2697,11 @@ class FederationHandler(BaseHandler): into a normal m.room.member invite. Args: - room_id: The ID of the room. - - event_dict (dict[str, Any]): Dictionary containing the event body. + event_dict: Dictionary containing the event body. """ - room_version = await self.store.get_room_version_id(room_id) + assert_params_in_dict(event_dict, ["room_id"]) + room_version = await self.store.get_room_version_id(event_dict["room_id"]) # NB: event_dict has a particular specced format we might need to fudge # if we change event formats too much. @@ -2756,18 +2710,6 @@ class FederationHandler(BaseHandler): event, context = await self.event_creation_handler.create_new_client_event( builder=builder ) - - event_allowed = await self.third_party_event_rules.check_event_allowed( - event, context - ) - if not event_allowed: - logger.warning( - "Exchange of threepid invite %s forbidden by third-party rules", event - ) - raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN - ) - event, context = await self.add_display_name_to_third_party_invite( room_version, event_dict, event, context ) @@ -2966,17 +2908,20 @@ class FederationHandler(BaseHandler): return result["max_stream_id"] else: assert self.storage.persistence - max_stream_token = await self.storage.persistence.persist_events( + + # Note that this returns the events that were persisted, which may not be + # the same as were passed in if some were deduplicated due to transaction IDs. 
+ events, max_stream_token = await self.storage.persistence.persist_events( event_and_contexts, backfilled=backfilled ) if self._ephemeral_messages_enabled: - for (event, context) in event_and_contexts: + for event in events: # If there's an expiry timestamp on the event, schedule its expiry. self._message_handler.maybe_schedule_expiry(event) if not backfilled: # Never notify for backfilled events - for event, _ in event_and_contexts: + for event in events: await self._notify_persisted_event(event, max_stream_token) return max_stream_token.stream @@ -3008,6 +2953,9 @@ class FederationHandler(BaseHandler): elif event.internal_metadata.is_outlier(): return + # the event has been persisted so it should have a stream ordering. + assert event.internal_metadata.stream_ordering + event_pos = PersistedEventPosition( self._instance_name, event.internal_metadata.stream_ordering ) diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
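The persist_events change above means the follow-up work (expiry scheduling, notifications) must iterate the events that were actually persisted rather than the events submitted, since duplicates can be dropped by transaction ID. A toy stand-in for that behaviour (persist here is not the real storage API):

from typing import Iterable, List, Tuple


def persist(submitted: Iterable[dict]) -> Tuple[List[dict], int]:
    # Drop events whose transaction ID was already seen, as deduplication would.
    seen = set()
    persisted = []
    for event in submitted:
        txn_id = event.get("txn_id")
        if txn_id is not None and txn_id in seen:
            continue
        seen.add(txn_id)
        persisted.append(event)
    return persisted, len(persisted)


events, max_stream = persist(
    [{"txn_id": "a", "body": 1}, {"txn_id": "a", "body": 1}, {"txn_id": "b", "body": 2}]
)
for event in events:
    # Notify only for what was actually stored.
    print(event)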
index 9684e60fc8..df29edeb83 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py
@@ -17,7 +17,7 @@ import logging from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError -from synapse.types import get_domain_from_id +from synapse.types import GroupID, get_domain_from_id logger = logging.getLogger(__name__) @@ -28,6 +28,9 @@ def _create_rerouter(func_name): """ async def f(self, group_id, *args, **kwargs): + if not GroupID.is_valid(group_id): + raise SynapseError(400, "%s is not a legal group ID" % (group_id,)) + if self.is_mine_id(group_id): return await getattr(self.groups_server_handler, func_name)( group_id, *args, **kwargs @@ -346,7 +349,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): server_name=get_domain_from_id(group_id), ) - # TODO: Check that the group is public and we're being added publically + # TODO: Check that the group is public and we're being added publicly is_publicised = content.get("publicise", False) token = await self.store.register_user_group_membership( @@ -391,7 +394,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): server_name=get_domain_from_id(group_id), ) - # TODO: Check that the group is public and we're being added publically + # TODO: Check that the group is public and we're being added publicly is_publicised = content.get("publicise", False) token = await self.store.register_user_group_membership( diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
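The guard added to _create_rerouter above rejects malformed group IDs before deciding whether the group is hosted locally or proxied to the server named in the ID. A simplified sketch of that routing decision (is_valid_group_id is a loose stand-in for GroupID.is_valid):

def is_valid_group_id(group_id: str) -> bool:
    # Group IDs look like "+localpart:domain"; the real check is stricter.
    return group_id.startswith("+") and ":" in group_id[1:]


def route_group_request(group_id: str, local_domain: str) -> str:
    if not is_valid_group_id(group_id):
        raise ValueError("%s is not a legal group ID" % (group_id,))
    domain = group_id.split(":", 1)[1]
    return "local" if domain == local_domain else "remote:" + domain


print(route_group_request("+books:example.com", "example.com"))  # local
print(route_group_request("+books:other.org", "example.com"))    # remote:other.org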
index bc3e9607ca..c05036ad1f 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py
@@ -46,15 +46,17 @@ class IdentityHandler(BaseHandler): def __init__(self, hs): super().__init__(hs) + # An HTTP client for contacting trusted URLs. self.http_client = SimpleHttpClient(hs) - # We create a blacklisting instance of SimpleHttpClient for contacting identity - # servers specified by clients + # An HTTP client for contacting identity servers specified by clients. self.blacklisting_http_client = SimpleHttpClient( hs, ip_blacklist=hs.config.federation_ip_range_blacklist ) - self.federation_http_client = hs.get_http_client() + self.federation_http_client = hs.get_federation_http_client() self.hs = hs + self._web_client_location = hs.config.invite_client_location + async def threepid_from_creds( self, id_server: str, creds: Dict[str, str] ) -> Optional[JsonDict]: @@ -354,7 +356,8 @@ class IdentityHandler(BaseHandler): raise SynapseError(500, "An error was encountered when sending the email") token_expires = ( - self.hs.clock.time_msec() + self.hs.config.email_validation_token_lifetime + self.hs.get_clock().time_msec() + + self.hs.config.email_validation_token_lifetime ) await self.store.start_or_continue_validation_session( @@ -802,6 +805,9 @@ class IdentityHandler(BaseHandler): "sender_display_name": inviter_display_name, "sender_avatar_url": inviter_avatar_url, } + # If a custom web client location is available, include it in the request. + if self._web_client_location: + invite_config["org.matrix.web_client_location"] = self._web_client_location # Add the identity service access token to the JSON body and use the v2 # Identity Service endpoints if id_access_token is present diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
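The identity handler change above only adds the configured web client location to the third-party invite payload when one is set, under the org.matrix.web_client_location key. A minimal sketch of that conditional field (other invite fields are elided, and build_invite_config is an invented helper):

from typing import Optional


def build_invite_config(
    web_client_location: Optional[str], medium: str, address: str
) -> dict:
    invite_config = {
        "medium": medium,
        "address": address,
        # sender / room / display-name fields elided for brevity
    }
    if web_client_location:
        invite_config["org.matrix.web_client_location"] = web_client_location
    return invite_config


print(build_invite_config("https://chat.example.com", "email", "alice@example.com"))
print(build_invite_config(None, "email", "alice@example.com"))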
index 39a85801c1..fbd8df9dcc 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py
@@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, Tuple from twisted.internet import defer @@ -47,12 +47,14 @@ class InitialSyncHandler(BaseHandler): self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() - self.snapshot_cache = ResponseCache(hs, "initial_sync_cache") + self.snapshot_cache = ResponseCache( + hs, "initial_sync_cache" + ) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]] self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() self.state_store = self.storage.state - def snapshot_all_rooms( + async def snapshot_all_rooms( self, user_id: str, pagin_config: PaginationConfig, @@ -84,7 +86,7 @@ class InitialSyncHandler(BaseHandler): include_archived, ) - return self.snapshot_cache.wrap( + return await self.snapshot_cache.wrap( key, self._snapshot_all_rooms, user_id, @@ -291,6 +293,10 @@ class InitialSyncHandler(BaseHandler): user_id, room_id, pagin_config, membership, is_peeking ) elif membership == Membership.LEAVE: + # The member_event_id will always be available if membership is set + # to leave. + assert member_event_id + result = await self._room_initial_sync_parted( user_id, room_id, pagin_config, membership, member_event_id, is_peeking ) @@ -313,13 +319,11 @@ class InitialSyncHandler(BaseHandler): user_id: str, room_id: str, pagin_config: PaginationConfig, - membership: Membership, + membership: str, member_event_id: str, is_peeking: bool, ) -> JsonDict: - room_state = await self.state_store.get_state_for_events([member_event_id]) - - room_state = room_state[member_event_id] + room_state = await self.state_store.get_state_for_event(member_event_id) limit = pagin_config.limit if pagin_config else None if limit is None: @@ -365,7 +369,7 @@ class InitialSyncHandler(BaseHandler): user_id: str, room_id: str, pagin_config: PaginationConfig, - membership: Membership, + membership: str, is_peeking: bool, ) -> JsonDict: current_state = await self.state.get_current_state(room_id=room_id) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
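snapshot_all_rooms above now awaits ResponseCache.wrap(), which lets concurrent identical initial-sync requests share a single in-flight computation. The toy cache below models only that in-flight sharing, not the real ResponseCache's timeout or result-retention behaviour:

import asyncio
from typing import Any, Awaitable, Callable, Dict


class ToyResponseCache:
    def __init__(self) -> None:
        self._pending = {}  # type: Dict[Any, asyncio.Future]

    async def wrap(self, key: Any, callback: Callable[..., Awaitable], *args) -> Any:
        if key not in self._pending:
            self._pending[key] = asyncio.ensure_future(callback(*args))
        try:
            return await self._pending[key]
        finally:
            self._pending.pop(key, None)


async def demo() -> None:
    calls = 0

    async def expensive(user_id: str) -> str:
        nonlocal calls
        calls += 1
        await asyncio.sleep(0.01)
        return "snapshot for " + user_id

    cache = ToyResponseCache()
    results = await asyncio.gather(
        cache.wrap("@alice:example.com", expensive, "@alice:example.com"),
        cache.wrap("@alice:example.com", expensive, "@alice:example.com"),
    )
    print(results, "calls:", calls)  # identical results, calls == 1


asyncio.run(demo())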
index ee271e85e5..9dfeab09cd 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py
@@ -50,15 +50,15 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester -from synapse.util import json_decoder +from synapse.util import json_decoder, json_encoder from synapse.util.async_helpers import Linearizer -from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.metrics import measure_func from synapse.visibility import filter_events_for_client from ._base import BaseHandler if TYPE_CHECKING: + from synapse.events.third_party_rules import ThirdPartyEventRules from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -393,27 +393,31 @@ class EventCreationHandler: self.action_generator = hs.get_action_generator() self.spam_checker = hs.get_spam_checker() - self.third_party_event_rules = hs.get_third_party_event_rules() + self.third_party_event_rules = ( + self.hs.get_third_party_event_rules() + ) # type: ThirdPartyEventRules self._block_events_without_consent_error = ( self.config.block_events_without_consent_error ) + # we need to construct a ConsentURIBuilder here, as it checks that the necessary + # config options, but *only* if we have a configuration for which we are + # going to need it. + if self._block_events_without_consent_error: + self._consent_uri_builder = ConsentURIBuilder(self.config) + # Rooms which should be excluded from dummy insertion. (For instance, # those without local users who can send events into the room). # # map from room id to time-of-last-attempt. # self._rooms_to_exclude_from_dummy_event_insertion = {} # type: Dict[str, int] - - # we need to construct a ConsentURIBuilder here, as it checks that the necessary - # config options, but *only* if we have a configuration for which we are - # going to need it. - if self._block_events_without_consent_error: - self._consent_uri_builder = ConsentURIBuilder(self.config) + # The number of forward extremeities before a dummy event is sent. + self._dummy_events_threshold = hs.config.dummy_events_threshold if ( - not self.config.worker_app + self.config.run_background_tasks and self.config.cleanup_extremities_with_dummy_events ): self.clock.looping_call( @@ -428,15 +432,13 @@ class EventCreationHandler: self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages - self._dummy_events_threshold = hs.config.dummy_events_threshold - async def create_event( self, requester: Requester, event_dict: dict, - token_id: Optional[str] = None, txn_id: Optional[str] = None, prev_event_ids: Optional[List[str]] = None, + auth_event_ids: Optional[List[str]] = None, require_consent: bool = True, ) -> Tuple[EventBase, EventContext]: """ @@ -450,13 +452,18 @@ class EventCreationHandler: Args: requester event_dict: An entire event - token_id txn_id prev_event_ids: the forward extremities to use as the prev_events for the new event. If None, they will be requested from the database. + + auth_event_ids: + The event ids to use as the auth_events for the new event. + Should normally be left as None, which will cause them to be calculated + based on the room state at the prev_events. + require_consent: Whether to check if the requester has consented to the privacy policy. 
Raises: @@ -465,7 +472,7 @@ class EventCreationHandler: Returns: Tuple of created event, Context """ - await self.auth.check_auth_blocking(requester.user.to_string()) + await self.auth.check_auth_blocking(requester=requester) if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "": room_version = event_dict["content"]["room_version"] @@ -508,14 +515,17 @@ class EventCreationHandler: if require_consent and not is_exempt: await self.assert_accepted_privacy_policy(requester) - if token_id is not None: - builder.internal_metadata.token_id = token_id + if requester.access_token_id is not None: + builder.internal_metadata.token_id = requester.access_token_id if txn_id is not None: builder.internal_metadata.txn_id = txn_id event, context = await self.create_new_client_event( - builder=builder, requester=requester, prev_event_ids=prev_event_ids, + builder=builder, + requester=requester, + prev_event_ids=prev_event_ids, + auth_event_ids=auth_event_ids, ) # In an ideal world we wouldn't need the second part of this condition. However, @@ -609,7 +619,13 @@ class EventCreationHandler: if requester.app_service is not None: return - user_id = requester.user.to_string() + user_id = requester.authenticated_entity + if not user_id.startswith("@"): + # The authenticated entity might not be a user, e.g. if it's the + # server puppetting the user. + return + + user = UserID.from_string(user_id) # exempt the system notices user if ( @@ -629,65 +645,10 @@ class EventCreationHandler: if u["consent_version"] == self.config.user_consent_version: return - consent_uri = self._consent_uri_builder.build_user_consent_uri( - requester.user.localpart - ) + consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart) msg = self._block_events_without_consent_error % {"consent_uri": consent_uri} raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri) - async def send_nonmember_event( - self, - requester: Requester, - event: EventBase, - context: EventContext, - ratelimit: bool = True, - ignore_shadow_ban: bool = False, - ) -> int: - """ - Persists and notifies local clients and federation of an event. - - Args: - requester: The requester sending the event. - event: The event to send. - context: The context of the event. - ratelimit: Whether to rate limit this send. - ignore_shadow_ban: True if shadow-banned users should be allowed to - send this event. - - Return: - The stream_id of the persisted event. - - Raises: - ShadowBanError if the requester has been shadow-banned. - """ - if event.type == EventTypes.Member: - raise SynapseError( - 500, "Tried to send member event through non-member codepath" - ) - - if not ignore_shadow_ban and requester.shadow_banned: - # We randomly sleep a bit just to annoy the requester. 
- await self.clock.sleep(random.randint(1, 10)) - raise ShadowBanError() - - user = UserID.from_string(event.sender) - - assert self.hs.is_mine(user), "User must be our own: %s" % (user,) - - if event.is_state(): - prev_event = await self.deduplicate_state_event(event, context) - if prev_event is not None: - logger.info( - "Not bothering to persist state event %s duplicated by %s", - event.event_id, - prev_event.event_id, - ) - return await self.store.get_stream_id_for_event(prev_event.event_id) - - return await self.handle_new_client_event( - requester=requester, event=event, context=context, ratelimit=ratelimit - ) - async def deduplicate_state_event( self, event: EventBase, context: EventContext ) -> Optional[EventBase]: @@ -699,7 +660,7 @@ class EventCreationHandler: context: The event context. Returns: - The previous verion of the event is returned, if it is found in the + The previous version of the event is returned, if it is found in the event context. Otherwise, None is returned. """ prev_state_ids = await context.get_prev_state_ids() @@ -728,7 +689,7 @@ class EventCreationHandler: """ Creates an event, then sends it. - See self.create_event and self.send_nonmember_event. + See self.create_event and self.handle_new_client_event. Args: requester: The requester sending the event. @@ -738,9 +699,19 @@ class EventCreationHandler: ignore_shadow_ban: True if shadow-banned users should be allowed to send this event. + Returns: + The event, and its stream ordering (if deduplication happened, + the previous, duplicate event). + Raises: ShadowBanError if the requester has been shadow-banned. """ + + if event_dict["type"] == EventTypes.Member: + raise SynapseError( + 500, "Tried to send member event through non-member codepath" + ) + if not ignore_shadow_ban and requester.shadow_banned: # We randomly sleep a bit just to annoy the requester. await self.clock.sleep(random.randint(1, 10)) @@ -752,24 +723,44 @@ class EventCreationHandler: # extremities to pile up, which in turn leads to state resolution # taking longer. 
with (await self.limiter.queue(event_dict["room_id"])): + if txn_id and requester.access_token_id: + existing_event_id = await self.store.get_event_id_from_transaction_id( + event_dict["room_id"], + requester.user.to_string(), + requester.access_token_id, + txn_id, + ) + if existing_event_id: + event = await self.store.get_event(existing_event_id) + # we know it was persisted, so must have a stream ordering + assert event.internal_metadata.stream_ordering + return event, event.internal_metadata.stream_ordering + event, context = await self.create_event( - requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id + requester, event_dict, txn_id=txn_id + ) + + assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( + event.sender, ) - spam_error = self.spam_checker.check_event_for_spam(event) + spam_error = await self.spam_checker.check_event_for_spam(event) if spam_error: if not isinstance(spam_error, str): spam_error = "Spam is not permitted here" raise SynapseError(403, spam_error, Codes.FORBIDDEN) - stream_id = await self.send_nonmember_event( - requester, - event, - context, + ev = await self.handle_new_client_event( + requester=requester, + event=event, + context=context, ratelimit=ratelimit, ignore_shadow_ban=ignore_shadow_ban, ) - return event, stream_id + + # we know it was persisted, so must have a stream ordering + assert ev.internal_metadata.stream_ordering + return ev, ev.internal_metadata.stream_ordering @measure_func("create_new_client_event") async def create_new_client_event( @@ -777,6 +768,7 @@ class EventCreationHandler: builder: EventBuilder, requester: Optional[Requester] = None, prev_event_ids: Optional[List[str]] = None, + auth_event_ids: Optional[List[str]] = None, ) -> Tuple[EventBase, EventContext]: """Create a new event for a local client @@ -789,6 +781,11 @@ class EventCreationHandler: If None, they will be requested from the database. + auth_event_ids: + The event ids to use as the auth_events for the new event. + Should normally be left as None, which will cause them to be calculated + based on the room state at the prev_events. + Returns: Tuple of created event, context """ @@ -810,11 +807,30 @@ class EventCreationHandler: builder.type == EventTypes.Create or len(prev_event_ids) > 0 ), "Attempting to create an event with no prev_events" - event = await builder.build(prev_event_ids=prev_event_ids) + event = await builder.build( + prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids + ) context = await self.state.compute_event_context(event) if requester: context.app_service = requester.app_service + third_party_result = await self.third_party_event_rules.check_event_allowed( + event, context + ) + if not third_party_result: + logger.info( + "Event %s forbidden by third-party rules", event, + ) + raise SynapseError( + 403, "This event is not allowed in this context", Codes.FORBIDDEN + ) + elif isinstance(third_party_result, dict): + # the third-party rules want to replace the event. We'll need to build a new + # event. + event, context = await self._rebuild_event_after_third_party_rules( + third_party_result, event + ) + self.validator.validate_new(event, self.config) # If this event is an annotation then we check that that the sender @@ -843,8 +859,11 @@ class EventCreationHandler: context: EventContext, ratelimit: bool = True, extra_users: List[UserID] = [], - ) -> int: - """Processes a new event. This includes checking auth, persisting it, + ignore_shadow_ban: bool = False, + ) -> EventBase: + """Processes a new event. 
+ + This includes deduplicating, checking auth, persisting, notifying users, sending to remote servers, etc. If called from a worker will hit out to the master process for final @@ -857,10 +876,39 @@ class EventCreationHandler: ratelimit extra_users: Any extra users to notify about event + ignore_shadow_ban: True if shadow-banned users should be allowed to + send this event. + Return: - The stream_id of the persisted event. + If the event was deduplicated, the previous, duplicate, event. Otherwise, + `event`. + + Raises: + ShadowBanError if the requester has been shadow-banned. """ + # we don't apply shadow-banning to membership events here. Invites are blocked + # higher up the stack, and we allow shadow-banned users to send join and leave + # events as normal. + if ( + event.type != EventTypes.Member + and not ignore_shadow_ban + and requester.shadow_banned + ): + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + raise ShadowBanError() + + if event.is_state(): + prev_event = await self.deduplicate_state_event(event, context) + if prev_event is not None: + logger.info( + "Not bothering to persist state event %s duplicated by %s", + event.event_id, + prev_event.event_id, + ) + return prev_event + if event.is_state() and (event.type, event.state_key) == ( EventTypes.Create, "", @@ -869,14 +917,6 @@ class EventCreationHandler: else: room_version = await self.store.get_room_version_id(event.room_id) - event_allowed = await self.third_party_event_rules.check_event_allowed( - event, context - ) - if not event_allowed: - raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN - ) - if event.internal_metadata.is_out_of_band_membership(): # the only sort of out-of-band-membership events we expect to see here # are invite rejections we have generated ourselves. @@ -891,7 +931,7 @@ class EventCreationHandler: # Ensure that we can round trip before trying to persist in db try: - dump = frozendict_json_encoder.encode(event.content) + dump = json_encoder.encode(event.content) json_decoder.decode(dump) except Exception: logger.exception("Failed to encode content: %r", event.content) @@ -914,14 +954,24 @@ class EventCreationHandler: extra_users=extra_users, ) stream_id = result["stream_id"] - event.internal_metadata.stream_ordering = stream_id - return stream_id - - stream_id = await self.persist_and_notify_client_event( + event_id = result["event_id"] + if event_id != event.event_id: + # If we get a different event back then it means that its + # been de-duplicated, so we replace the given event with the + # one already persisted. + event = await self.store.get_event(event_id) + else: + # If we newly persisted the event then we need to update its + # stream_ordering entry manually (as it was persisted on + # another worker). + event.internal_metadata.stream_ordering = stream_id + return event + + event = await self.persist_and_notify_client_event( requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) - return stream_id + return event except Exception: # Ensure that we actually remove the entries in the push actions # staging area, if we calculated them. @@ -966,11 +1016,16 @@ class EventCreationHandler: context: EventContext, ratelimit: bool = True, extra_users: List[UserID] = [], - ) -> int: + ) -> EventBase: """Called when we have fully built the event, have already calculated the push actions for the event, and checked auth. This should only be run on the instance in charge of persisting events. 
+ + Returns: + The persisted event. This may be different than the given event if + it was de-duplicated (e.g. because we had already persisted an + event with the same transaction ID.) """ assert self.storage.persistence is not None assert self._events_shard_config.should_handle( @@ -1018,7 +1073,7 @@ class EventCreationHandler: # Check the alias is currently valid (if it has changed). room_alias_str = event.content.get("alias", None) - directory_handler = self.hs.get_handlers().directory_handler + directory_handler = self.hs.get_directory_handler() if room_alias_str and room_alias_str != original_alias: await self._validate_canonical_alias( directory_handler, room_alias_str, event.room_id @@ -1044,38 +1099,17 @@ class EventCreationHandler: directory_handler, alias_str, event.room_id ) - federation_handler = self.hs.get_handlers().federation_handler + federation_handler = self.hs.get_federation_handler() if event.type == EventTypes.Member: if event.content["membership"] == Membership.INVITE: - - def is_inviter_member_event(e): - return e.type == EventTypes.Member and e.sender == event.sender - - current_state_ids = await context.get_current_state_ids() - - # We know this event is not an outlier, so this must be - # non-None. - assert current_state_ids is not None - - state_to_include_ids = [ - e_id - for k, e_id in current_state_ids.items() - if k[0] in self.room_invite_state_types - or k == (EventTypes.Member, event.sender) - ] - - state_to_include = await self.store.get_events(state_to_include_ids) - - event.unsigned["invite_room_state"] = [ - { - "type": e.type, - "state_key": e.state_key, - "content": e.content, - "sender": e.sender, - } - for e in state_to_include.values() - ] + event.unsigned[ + "invite_room_state" + ] = await self.store.get_stripped_room_state_from_event_context( + context, + self.room_invite_state_types, + membership_user_id=event.sender, + ) invitee = UserID.from_string(event.state_key) if not self.hs.is_mine(invitee): @@ -1108,6 +1142,9 @@ class EventCreationHandler: if original_event.room_id != event.room_id: raise SynapseError(400, "Cannot redact event from a different room") + if original_event.type == EventTypes.ServerACL: + raise AuthError(403, "Redacting server ACL events is not permitted") + prev_state_ids = await context.get_prev_state_ids() auth_events_ids = self.auth.compute_auth_events( event, prev_state_ids, for_verification=True @@ -1138,9 +1175,13 @@ class EventCreationHandler: if prev_state_ids: raise AuthError(403, "Changing the room create event is forbidden") - event_pos, max_stream_token = await self.storage.persistence.persist_event( - event, context=context - ) + # Note that this returns the event that was persisted, which may not be + # the same as we passed in if it was deduplicated due transaction IDs. + ( + event, + event_pos, + max_stream_token, + ) = await self.storage.persistence.persist_event(event, context=context) if self._ephemeral_events_enabled: # If there's an expiry timestamp on the event, schedule its expiry. @@ -1161,7 +1202,7 @@ class EventCreationHandler: # matters as sometimes presence code can take a while. 
run_in_background(self._bump_active_time, requester.user) - return event_pos.stream + return event async def _bump_active_time(self, user: UserID) -> None: try: @@ -1215,12 +1256,12 @@ class EventCreationHandler: for user_id in members: if not self.hs.is_mine_id(user_id): continue - requester = create_requester(user_id) + requester = create_requester(user_id, authenticated_entity=self.server_name) try: event, context = await self.create_event( requester, { - "type": "org.matrix.dummy_event", + "type": EventTypes.Dummy, "content": {}, "room_id": room_id, "sender": user_id, @@ -1232,15 +1273,10 @@ class EventCreationHandler: # Since this is a dummy-event it is OK if it is sent by a # shadow-banned user. - await self.send_nonmember_event( + await self.handle_new_client_event( requester, event, context, ratelimit=False, ignore_shadow_ban=True, ) return True - except ConsentNotGivenError: - logger.info( - "Failed to send dummy event into room %s for user %s due to " - "lack of consent. Will try another user" % (room_id, user_id) - ) except AuthError: logger.info( "Failed to send dummy event into room %s for user %s due to " @@ -1260,3 +1296,62 @@ class EventCreationHandler: room_id, ) del self._rooms_to_exclude_from_dummy_event_insertion[room_id] + + async def _rebuild_event_after_third_party_rules( + self, third_party_result: dict, original_event: EventBase + ) -> Tuple[EventBase, EventContext]: + # the third_party_event_rules want to replace the event. + # we do some basic checks, and then return the replacement event and context. + + # Construct a new EventBuilder and validate it, which helps with the + # rest of these checks. + try: + builder = self.event_builder_factory.for_room_version( + original_event.room_version, third_party_result + ) + self.validator.validate_builder(builder) + except SynapseError as e: + raise Exception( + "Third party rules module created an invalid event: " + e.msg, + ) + + immutable_fields = [ + # changing the room is going to break things: we've already checked that the + # room exists, and are holding a concurrency limiter token for that room. + # Also, we might need to use a different room version. + "room_id", + # changing the type or state key might work, but we'd need to check that the + # calling functions aren't making assumptions about them. + "type", + "state_key", + ] + + for k in immutable_fields: + if getattr(builder, k, None) != original_event.get(k): + raise Exception( + "Third party rules module created an invalid event: " + "cannot change field " + k + ) + + # check that the new sender belongs to this HS + if not self.hs.is_mine_id(builder.sender): + raise Exception( + "Third party rules module created an invalid event: " + "invalid sender " + builder.sender + ) + + # copy over the original internal metadata + for k, v in original_event.internal_metadata.get_dict().items(): + setattr(builder.internal_metadata, k, v) + + # the event type hasn't changed, so there's no point in re-calculating the + # auth events. + event = await builder.build( + prev_event_ids=original_event.prev_event_ids(), + auth_event_ids=original_event.auth_event_ids(), + ) + + # we rebuild the event context, to be on the safe side. If nothing else, + # delta_ids might need an update. + context = await self.state.compute_event_context(event) + return event, context diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
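One of the larger behavioural changes in the synapse/handlers/message.py hunks above is transaction-ID based deduplication: when a client retries a send with the same txn_id and access token, the already-persisted event is returned instead of a new one being created. A minimal in-memory sketch of the idea (the real mapping lives in the datastore via get_event_id_from_transaction_id, not in a dict):

from typing import Dict, Optional, Tuple


class TransactionDeduplicator:
    def __init__(self) -> None:
        # (room_id, user_id, access_token_id, txn_id) -> event_id
        self._seen = {}  # type: Dict[Tuple[str, str, int, str], str]

    def get_existing(
        self, room_id: str, user_id: str, token_id: int, txn_id: str
    ) -> Optional[str]:
        # Return the event id already persisted for this transaction, if any.
        return self._seen.get((room_id, user_id, token_id, txn_id))

    def record(
        self, room_id: str, user_id: str, token_id: int, txn_id: str, event_id: str
    ) -> None:
        self._seen[(room_id, user_id, token_id, txn_id)] = event_id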
index d625bef4f0..89d5433259 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py
@@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect import logging from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar from urllib.parse import urlencode, urlparse, urlunparse, parse_qsl @@ -34,7 +35,8 @@ from typing_extensions import TypedDict from twisted.web.client import readBody from synapse.config import ConfigError -from synapse.http.server import respond_with_html +from synapse.handlers._base import BaseHandler +from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart @@ -83,19 +85,15 @@ class OidcError(Exception): return self.error -class MappingException(Exception): - """Used to catch errors when mapping the UserInfo object - """ - - -class OidcHandler: +class OidcHandler(BaseHandler): """Handles requests related to the OpenID Connect login flow. """ def __init__(self, hs: "HomeServer"): - self.hs = hs + super().__init__(hs) self._callback_url = hs.config.oidc_callback_url # type: str self._scopes = hs.config.oidc_scopes # type: List[str] + self._user_profile_method = hs.config.oidc_user_profile_method # type: str self._client_auth = ClientAuth( hs.config.oidc_client_id, hs.config.oidc_client_secret, @@ -117,38 +115,13 @@ class OidcHandler: self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool self._http_client = hs.get_proxied_http_client() - self._auth_handler = hs.get_auth_handler() - self._registration_handler = hs.get_registration_handler() - self._datastore = hs.get_datastore() - self._clock = hs.get_clock() - self._hostname = hs.hostname # type: str self._server_name = hs.config.server_name # type: str self._macaroon_secret_key = hs.config.macaroon_secret_key - self._error_template = hs.config.sso_error_template # identifier for the external_ids table self._auth_provider_id = "oidc" - def _render_error( - self, request, error: str, error_description: Optional[str] = None - ) -> None: - """Render the error template and respond to the request with it. - - This is used to show errors to the user. The template of this page can - be found under `synapse/res/templates/sso_error.html`. - - Args: - request: The incoming request from the browser. - We'll respond with an HTML page describing the error. - error: A technical identifier for this error. Those include - well-known OAuth2/OIDC error types like invalid_request or - access_denied. - error_description: A human-readable description of the error. - """ - html = self._error_template.render( - error=error, error_description=error_description - ) - respond_with_html(request, 400, html) + self._sso_handler = hs.get_sso_handler() def _validate_metadata(self): """Verifies the provider metadata. @@ -196,11 +169,11 @@ class OidcHandler: % (m["response_types_supported"],) ) - # If the openid scope was not requested, we need a userinfo endpoint to fetch user infos + # Ensure there's a userinfo endpoint to fetch from if it is required. 
if self._uses_userinfo: if m.get("userinfo_endpoint") is None: raise ValueError( - 'provider has no "userinfo_endpoint", even though it is required because the "openid" scope is not requested' + 'provider has no "userinfo_endpoint", even though it is required' ) else: # If we're not using userinfo, we need a valid jwks to validate the ID token @@ -216,12 +189,14 @@ class OidcHandler: This is based on the requested scopes: if the scopes include ``openid``, the provider should give use an ID token containing the - user informations. If not, we should fetch them using the + user information. If not, we should fetch them using the ``access_token`` with the ``userinfo_endpoint``. """ - # Maybe that should be user-configurable and not inferred? - return "openid" not in self._scopes + return ( + "openid" not in self._scopes + or self._user_profile_method == "userinfo_endpoint" + ) async def load_metadata(self) -> OpenIDProviderMetadata: """Load and validate the provider metadata. @@ -423,7 +398,7 @@ class OidcHandler: return resp async def _fetch_userinfo(self, token: Token) -> UserInfo: - """Fetch user informations from the ``userinfo_endpoint``. + """Fetch user information from the ``userinfo_endpoint``. Args: token: the token given by the ``token_endpoint``. @@ -586,7 +561,7 @@ class OidcHandler: Since we might want to display OIDC-related errors in a user-friendly way, we don't raise SynapseError from here. Instead, we call - ``self._render_error`` which displays an HTML page for the error. + ``self._sso_handler.render_error`` which displays an HTML page for the error. Most of the OpenID Connect logic happens here: @@ -624,7 +599,7 @@ class OidcHandler: if error != "access_denied": logger.error("Error from the OIDC provider: %s %s", error, description) - self._render_error(request, error, description) + self._sso_handler.render_error(request, error, description) return # otherwise, it is presumably a successful response. see: @@ -634,7 +609,9 @@ class OidcHandler: session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes] if session is None: logger.info("No session cookie found") - self._render_error(request, "missing_session", "No session cookie found") + self._sso_handler.render_error( + request, "missing_session", "No session cookie found" + ) return # Remove the cookie. 
There is a good chance that if the callback failed @@ -652,7 +629,9 @@ class OidcHandler: # Check for the state query parameter if b"state" not in request.args: logger.info("State parameter is missing") - self._render_error(request, "invalid_request", "State parameter is missing") + self._sso_handler.render_error( + request, "invalid_request", "State parameter is missing" + ) return state = request.args[b"state"][0].decode() @@ -666,17 +645,19 @@ class OidcHandler: ) = self._verify_oidc_session_token(session, state) except MacaroonDeserializationException as e: logger.exception("Invalid session") - self._render_error(request, "invalid_session", str(e)) + self._sso_handler.render_error(request, "invalid_session", str(e)) return except MacaroonInvalidSignatureException as e: logger.exception("Could not verify session") - self._render_error(request, "mismatching_session", str(e)) + self._sso_handler.render_error(request, "mismatching_session", str(e)) return # Exchange the code with the provider if b"code" not in request.args: logger.info("Code parameter is missing") - self._render_error(request, "invalid_request", "Code parameter is missing") + self._sso_handler.render_error( + request, "invalid_request", "Code parameter is missing" + ) return logger.debug("Exchanging code") @@ -685,7 +666,7 @@ class OidcHandler: token = await self._exchange_code(code) except OidcError as e: logger.exception("Could not exchange code") - self._render_error(request, e.error, e.error_description) + self._sso_handler.render_error(request, e.error, e.error_description) return logger.debug("Successfully obtained OAuth2 access token") @@ -698,7 +679,7 @@ class OidcHandler: userinfo = await self._fetch_userinfo(token) except Exception as e: logger.exception("Could not fetch userinfo") - self._render_error(request, "fetch_error", str(e)) + self._sso_handler.render_error(request, "fetch_error", str(e)) return else: logger.debug("Extracting userinfo from id_token") @@ -706,43 +687,32 @@ class OidcHandler: userinfo = await self._parse_id_token(token, nonce=nonce) except Exception as e: logger.exception("Invalid id_token") - self._render_error(request, "invalid_token", str(e)) + self._sso_handler.render_error(request, "invalid_token", str(e)) return - # Pull out the user-agent and IP from the request. - user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ - 0 - ].decode("ascii", "surrogateescape") - ip_address = self.hs.get_ip_from_request(request) + # first check if we're doing a UIA + if ui_auth_session_id: + try: + remote_user_id = self._remote_id_from_userinfo(userinfo) + except Exception as e: + logger.exception("Could not extract remote user id") + self._sso_handler.render_error(request, "mapping_error", str(e)) + return + + return await self._sso_handler.complete_sso_ui_auth_request( + self._auth_provider_id, remote_user_id, ui_auth_session_id, request + ) + + # otherwise, it's a login # Call the mapper to register/login the user try: - user_id = await self._map_userinfo_to_user( - userinfo, token, user_agent, ip_address + await self._complete_oidc_login( + userinfo, token, request, client_redirect_url ) except MappingException as e: logger.exception("Could not map user") - self._render_error(request, "mapping_error", str(e)) - return - - # Mapping providers might not have get_extra_attributes: only call this - # method if it exists. 
- extra_attributes = None - get_extra_attributes = getattr( - self._user_mapping_provider, "get_extra_attributes", None - ) - if get_extra_attributes: - extra_attributes = await get_extra_attributes(userinfo, token) - - # and finally complete the login - if ui_auth_session_id: - await self._auth_handler.complete_sso_ui_auth( - user_id, ui_auth_session_id, request - ) - else: - await self._auth_handler.complete_sso_login( - user_id, request, client_redirect_url, extra_attributes - ) + self._sso_handler.render_error(request, "mapping_error", str(e)) def _generate_oidc_session_token( self, @@ -771,7 +741,7 @@ class OidcHandler: Defaults to an hour. Returns: - A signed macaroon token with the session informations. + A signed macaroon token with the session information. """ macaroon = pymacaroons.Macaroon( location=self._server_name, identifier="key", key=self._macaroon_secret_key, @@ -787,7 +757,7 @@ class OidcHandler: macaroon.add_first_party_caveat( "ui_auth_session_id = %s" % (ui_auth_session_id,) ) - now = self._clock.time_msec() + now = self.clock.time_msec() expiry = now + duration_in_ms macaroon.add_first_party_caveat("time < %d" % (expiry,)) @@ -862,13 +832,17 @@ class OidcHandler: if not caveat.startswith(prefix): return False expiry = int(caveat[len(prefix) :]) - now = self._clock.time_msec() + now = self.clock.time_msec() return now < expiry - async def _map_userinfo_to_user( - self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str - ) -> str: - """Maps a UserInfo object to a mxid. + async def _complete_oidc_login( + self, + userinfo: UserInfo, + token: Token, + request: SynapseRequest, + client_redirect_url: str, + ) -> None: + """Given a UserInfo response, complete the login flow UserInfo should have a claim that uniquely identifies users. This claim is usually `sub`, but can be configured with `oidc_config.subject_claim`. @@ -880,93 +854,118 @@ class OidcHandler: If a user already exists with the mxid we've mapped and allow_existing_users is disabled, raise an exception. + Otherwise, render a redirect back to the client_redirect_url with a loginToken. + Args: userinfo: an object representing the user token: a dict with the tokens obtained from the provider - user_agent: The user agent of the client making the request. - ip_address: The IP address of the client making the request. + request: The request to respond to + client_redirect_url: The redirect URL passed in by the client. Raises: MappingException: if there was an error while mapping some properties - - Returns: - The mxid of the user """ try: - remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo) + remote_user_id = self._remote_id_from_userinfo(userinfo) except Exception as e: raise MappingException( "Failed to extract subject from OIDC response: %s" % (e,) ) - # Some OIDC providers use integer IDs, but Synapse expects external IDs - # to be strings. - remote_user_id = str(remote_user_id) - logger.info( - "Looking for existing mapping for user %s:%s", - self._auth_provider_id, - remote_user_id, - ) - - registered_user_id = await self._datastore.get_user_by_external_id( - self._auth_provider_id, remote_user_id, + # Older mapping providers don't accept the `failures` argument, so we + # try and detect support. 
+ mapper_signature = inspect.signature( + self._user_mapping_provider.map_user_attributes ) + supports_failures = "failures" in mapper_signature.parameters - if registered_user_id is not None: - logger.info("Found existing mapping %s", registered_user_id) - return registered_user_id - - try: - attributes = await self._user_mapping_provider.map_user_attributes( - userinfo, token - ) - except Exception as e: - raise MappingException( - "Could not extract user attributes from OIDC response: " + str(e) - ) + async def oidc_response_to_user_attributes(failures: int) -> UserAttributes: + """ + Call the mapping provider to map the OIDC userinfo and token to user attributes. - logger.debug( - "Retrieved user attributes from user mapping provider: %r", attributes - ) + This is backwards compatibility for abstraction for the SSO handler. + """ + if supports_failures: + attributes = await self._user_mapping_provider.map_user_attributes( + userinfo, token, failures + ) + else: + # If the mapping provider does not support processing failures, + # do not continually generate the same Matrix ID since it will + # continue to already be in use. Note that the error raised is + # arbitrary and will get turned into a MappingException. + if failures: + raise MappingException( + "Mapping provider does not support de-duplicating Matrix IDs" + ) - if not attributes["localpart"]: - raise MappingException("localpart is empty") + attributes = await self._user_mapping_provider.map_user_attributes( # type: ignore + userinfo, token + ) - localpart = map_username_to_mxid_localpart(attributes["localpart"]) + return UserAttributes(**attributes) - user_id = UserID(localpart, self._hostname).to_string() - users = await self._datastore.get_users_by_id_case_insensitive(user_id) - if users: + async def grandfather_existing_users() -> Optional[str]: if self._allow_existing_users: - if len(users) == 1: - registered_user_id = next(iter(users)) - elif user_id in users: - registered_user_id = user_id - else: - raise MappingException( - "Attempted to login as '{}' but it matches more than one user inexactly: {}".format( - user_id, list(users.keys()) + # If allowing existing users we want to generate a single localpart + # and attempt to match it. + attributes = await oidc_response_to_user_attributes(failures=0) + + user_id = UserID(attributes.localpart, self.server_name).to_string() + users = await self.store.get_users_by_id_case_insensitive(user_id) + if users: + # If an existing matrix ID is returned, then use it. + if len(users) == 1: + previously_registered_user_id = next(iter(users)) + elif user_id in users: + previously_registered_user_id = user_id + else: + # Do not attempt to continue generating Matrix IDs. + raise MappingException( + "Attempted to login as '{}' but it matches more than one user inexactly: {}".format( + user_id, users + ) ) - ) - else: - # This mxid is taken - raise MappingException("mxid '{}' is already taken".format(user_id)) - else: - # It's the first time this user is logging in and the mapped mxid was - # not taken, register the user - registered_user_id = await self._registration_handler.register_user( - localpart=localpart, - default_display_name=attributes["display_name"], - user_agent_ips=(user_agent, ip_address), - ) - await self._datastore.record_user_external_id( - self._auth_provider_id, remote_user_id, registered_user_id, + + return previously_registered_user_id + + return None + + # Mapping providers might not have get_extra_attributes: only call this + # method if it exists. 
+ extra_attributes = None + get_extra_attributes = getattr( + self._user_mapping_provider, "get_extra_attributes", None + ) + if get_extra_attributes: + extra_attributes = await get_extra_attributes(userinfo, token) + + await self._sso_handler.complete_sso_login_request( + self._auth_provider_id, + remote_user_id, + request, + client_redirect_url, + oidc_response_to_user_attributes, + grandfather_existing_users, + extra_attributes, ) - return registered_user_id + def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str: + """Extract the unique remote id from an OIDC UserInfo block -UserAttribute = TypedDict( - "UserAttribute", {"localpart": str, "display_name": Optional[str]} + Args: + userinfo: An object representing the user given by the OIDC provider + Returns: + remote user id + """ + remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo) + # Some OIDC providers use integer IDs, but Synapse expects external IDs + # to be strings. + return str(remote_user_id) + + +UserAttributeDict = TypedDict( + "UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]} ) C = TypeVar("C") @@ -1009,13 +1008,15 @@ class OidcMappingProvider(Generic[C]): raise NotImplementedError() async def map_user_attributes( - self, userinfo: UserInfo, token: Token - ) -> UserAttribute: + self, userinfo: UserInfo, token: Token, failures: int + ) -> UserAttributeDict: """Map a `UserInfo` object into user attributes. Args: userinfo: An object representing the user given by the OIDC provider token: A dict with the tokens returned by the provider + failures: How many times a call to this function with this + UserInfo has resulted in a failure. Returns: A dict containing the ``localpart`` and (optionally) the ``display_name`` @@ -1045,10 +1046,10 @@ env = Environment(finalize=jinja_finalize) @attr.s class JinjaOidcMappingConfig: - subject_claim = attr.ib() # type: str - localpart_template = attr.ib() # type: Template - display_name_template = attr.ib() # type: Optional[Template] - extra_attributes = attr.ib() # type: Dict[str, Template] + subject_claim = attr.ib(type=str) + localpart_template = attr.ib(type=Optional[Template]) + display_name_template = attr.ib(type=Optional[Template]) + extra_attributes = attr.ib(type=Dict[str, Template]) class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): @@ -1064,18 +1065,14 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): def parse_config(config: dict) -> JinjaOidcMappingConfig: subject_claim = config.get("subject_claim", "sub") - if "localpart_template" not in config: - raise ConfigError( - "missing key: oidc_config.user_mapping_provider.config.localpart_template" - ) - - try: - localpart_template = env.from_string(config["localpart_template"]) - except Exception as e: - raise ConfigError( - "invalid jinja template for oidc_config.user_mapping_provider.config.localpart_template: %r" - % (e,) - ) + localpart_template = None # type: Optional[Template] + if "localpart_template" in config: + try: + localpart_template = env.from_string(config["localpart_template"]) + except Exception as e: + raise ConfigError( + "invalid jinja template", path=["localpart_template"] + ) from e display_name_template = None # type: Optional[Template] if "display_name_template" in config: @@ -1083,26 +1080,22 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): display_name_template = env.from_string(config["display_name_template"]) except Exception as e: raise ConfigError( - "invalid 
jinja template for oidc_config.user_mapping_provider.config.display_name_template: %r" - % (e,) - ) + "invalid jinja template", path=["display_name_template"] + ) from e extra_attributes = {} # type Dict[str, Template] if "extra_attributes" in config: extra_attributes_config = config.get("extra_attributes") or {} if not isinstance(extra_attributes_config, dict): - raise ConfigError( - "oidc_config.user_mapping_provider.config.extra_attributes must be a dict" - ) + raise ConfigError("must be a dict", path=["extra_attributes"]) for key, value in extra_attributes_config.items(): try: extra_attributes[key] = env.from_string(value) except Exception as e: raise ConfigError( - "invalid jinja template for oidc_config.user_mapping_provider.config.extra_attributes.%s: %r" - % (key, e) - ) + "invalid jinja template", path=["extra_attributes", key] + ) from e return JinjaOidcMappingConfig( subject_claim=subject_claim, @@ -1115,9 +1108,19 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): return userinfo[self._config.subject_claim] async def map_user_attributes( - self, userinfo: UserInfo, token: Token - ) -> UserAttribute: - localpart = self._config.localpart_template.render(user=userinfo).strip() + self, userinfo: UserInfo, token: Token, failures: int + ) -> UserAttributeDict: + localpart = None + + if self._config.localpart_template: + localpart = self._config.localpart_template.render(user=userinfo).strip() + + # Ensure only valid characters are included in the MXID. + localpart = map_username_to_mxid_localpart(localpart) + + # Append suffix integer if last call to this function failed to produce + # a usable mxid. + localpart += str(failures) if failures else "" display_name = None # type: Optional[str] if self._config.display_name_template is not None: @@ -1128,7 +1131,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): if display_name == "": display_name = None - return UserAttribute(localpart=localpart, display_name=display_name) + return UserAttributeDict(localpart=localpart, display_name=display_name) async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict: extras = {} # type: Dict[str, str] diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
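The Jinja mapping provider above now receives a failures count, so that a localpart which collided with an existing Matrix ID can be retried with a numeric suffix. A standalone approximation of that behaviour (omitting the map_username_to_mxid_localpart normalisation the real code applies):

def localpart_with_failures(base: str, failures: int) -> str:
    # First attempt keeps the template output as-is; after a collision the
    # failure count is appended, e.g. "alice" -> "alice1" -> "alice2".
    localpart = base.strip()
    if failures:
        localpart += str(failures)
    return localpart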
index 2c2a633938..5372753707 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py
@@ -92,7 +92,7 @@ class PaginationHandler: self._retention_allowed_lifetime_min = hs.config.retention_allowed_lifetime_min self._retention_allowed_lifetime_max = hs.config.retention_allowed_lifetime_max - if hs.config.retention_enabled: + if hs.config.run_background_tasks and hs.config.retention_enabled: # Run the purge jobs described in the configuration file. for job in hs.config.retention_purge_jobs: logger.info("Setting up purge job with config: %s", job) @@ -299,17 +299,22 @@ class PaginationHandler: """ return self._purges_by_id.get(purge_id) - async def purge_room(self, room_id: str) -> None: - """Purge the given room from the database""" + async def purge_room(self, room_id: str, force: bool = False) -> None: + """Purge the given room from the database. + + Args: + room_id: room to be purged + force: set true to skip checking for joined users. + """ with await self.pagination_lock.write(room_id): # check we know about the room await self.store.get_room_version_id(room_id) # first check that we have no users in this room - joined = await self.store.is_host_joined(room_id, self._server_name) - - if joined: - raise SynapseError(400, "Users are still joined to this room") + if not force: + joined = await self.store.is_host_joined(room_id, self._server_name) + if joined: + raise SynapseError(400, "Users are still joined to this room") await self.storage.purge_events.purge_room(room_id) @@ -383,7 +388,7 @@ class PaginationHandler: "room_key", leave_token ) - await self.hs.get_handlers().federation_handler.maybe_backfill( + await self.hs.get_federation_handler().maybe_backfill( room_id, curr_topo, limit=pagin_config.limit, ) diff --git a/synapse/handlers/password_policy.py b/synapse/handlers/password_policy.py
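purge_room in the pagination handler above gains a force flag that skips the joined-users check. An illustrative caller (handler here is assumed to be the PaginationHandler instance obtained from the homeserver):

async def purge_if_requested(handler, room_id: str, force: bool = False) -> None:
    # With force=False this raises a 400 error if local users are still joined
    # to the room; with force=True the purge proceeds regardless.
    await handler.purge_room(room_id, force=force)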
index 88e2f87200..6c635cc31b 100644 --- a/synapse/handlers/password_policy.py +++ b/synapse/handlers/password_policy.py
@@ -16,14 +16,18 @@ import logging import re +from typing import TYPE_CHECKING from synapse.api.errors import Codes, PasswordRefusedError +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class PasswordPolicyHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.policy = hs.config.password_policy self.enabled = hs.config.password_policy_enabled @@ -33,11 +37,11 @@ class PasswordPolicyHandler: self.regexp_uppercase = re.compile("[A-Z]") self.regexp_lowercase = re.compile("[a-z]") - def validate_password(self, password): + def validate_password(self, password: str) -> None: """Checks whether a given password complies with the server's policy. Args: - password (str): The password to check against the server's policy. + password: The password to check against the server's policy. Raises: PasswordRefusedError: The password doesn't comply with the server's policy. diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 1000ac95ff..22d1e9d35c 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py
@@ -25,7 +25,7 @@ The methods that define policy are: import abc import logging from contextlib import contextmanager -from typing import Dict, Iterable, List, Set, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple from prometheus_client import Counter from typing_extensions import ContextManager @@ -46,9 +46,8 @@ from synapse.util.caches.descriptors import cached from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer -MYPY = False -if MYPY: - import synapse.server +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -101,7 +100,7 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER class BasePresenceHandler(abc.ABC): """Parts of the PresenceHandler that are shared between workers and master""" - def __init__(self, hs: "synapse.server.HomeServer"): + def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.store = hs.get_datastore() @@ -199,7 +198,7 @@ class BasePresenceHandler(abc.ABC): class PresenceHandler(BasePresenceHandler): - def __init__(self, hs: "synapse.server.HomeServer"): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs self.is_mine_id = hs.is_mine_id @@ -802,7 +801,7 @@ class PresenceHandler(BasePresenceHandler): between the requested tokens due to the limit. The token returned can be used in a subsequent call to this - function to get further updatees. + function to get further updates. The updates are a list of 2-tuples of stream ID and the row data """ @@ -977,7 +976,7 @@ def should_notify(old_state, new_state): new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY ): - # Only notify about last active bumps if we're not currently acive + # Only notify about last active bumps if we're not currently active if not new_state.currently_active: notify_reason_counter.labels("last_active_change_online").inc() return True @@ -1011,7 +1010,7 @@ def format_user_presence_state(state, now, include_user_id=True): class PresenceEventSource: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): # We can't call get_presence_handler here because there's a cycle: # # Presence -> Notifier -> PresenceEventSource -> Presence @@ -1071,12 +1070,14 @@ class PresenceEventSource: users_interested_in = await self._get_interested_in(user, explicit_room_id) - user_ids_changed = set() + user_ids_changed = set() # type: Collection[str] changed = None if from_key: changed = stream_change_cache.get_all_entities_changed(from_key) if changed is not None and len(changed) < 500: + assert isinstance(user_ids_changed, set) + # For small deltas, its quicker to get all changes and then # work out if we share a room or they're in our presence list get_updates_counter.labels("stream").inc() diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
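The presence changes above swap the old MYPY flag for the standard typing.TYPE_CHECKING guard, keeping the HomeServer import out of the runtime import graph. The pattern in isolation looks like this:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only imported while type checking, so no import cycle at runtime.
    from synapse.server import HomeServer


class ExampleHandler:
    def __init__(self, hs: "HomeServer") -> None:
        self.clock = hs.get_clock()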
index 5453e6dfc8..dee0ef45e7 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py
@@ -12,9 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging import random +from typing import TYPE_CHECKING, Optional from synapse.api.errors import ( AuthError, @@ -24,26 +24,37 @@ from synapse.api.errors import ( StoreError, SynapseError, ) -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import UserID, create_requester, get_domain_from_id +from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.types import ( + JsonDict, + Requester, + UserID, + create_requester, + get_domain_from_id, +) from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) MAX_DISPLAYNAME_LEN = 256 MAX_AVATAR_URL_LEN = 1000 -class BaseProfileHandler(BaseHandler): +class ProfileHandler(BaseHandler): """Handles fetching and updating user profile information. - BaseProfileHandler can be instantiated directly on workers and will - delegate to master when necessary. The master process should use the - subclass MasterProfileHandler + ProfileHandler can be instantiated directly on workers and will + delegate to master when necessary. """ - def __init__(self, hs): + PROFILE_UPDATE_MS = 60 * 1000 + PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 + + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.federation = hs.get_federation_client() @@ -53,7 +64,12 @@ class BaseProfileHandler(BaseHandler): self.user_directory_handler = hs.get_user_directory_handler() - async def get_profile(self, user_id): + if hs.config.run_background_tasks: + self.clock.looping_call( + self._update_remote_profile_cache, self.PROFILE_UPDATE_MS + ) + + async def get_profile(self, user_id: str) -> JsonDict: target_user = UserID.from_string(user_id) if self.hs.is_mine(target_user): @@ -82,11 +98,18 @@ class BaseProfileHandler(BaseHandler): except RequestSendFailed as e: raise SynapseError(502, "Failed to fetch profile") from e except HttpResponseException as e: + if e.code < 500 and e.code != 404: + # Other codes are not allowed in c2s API + logger.info( + "Server replied with wrong response: %s %s", e.code, e.msg + ) + + raise SynapseError(502, "Failed to fetch profile") raise e.to_synapse_error() - async def get_profile_from_cache(self, user_id): + async def get_profile_from_cache(self, user_id: str) -> JsonDict: """Get the profile information from our local cache. If the user is - ours then the profile information will always be corect. Otherwise, + ours then the profile information will always be correct. Otherwise, it may be out of date/missing. 
""" target_user = UserID.from_string(user_id) @@ -108,7 +131,7 @@ class BaseProfileHandler(BaseHandler): profile = await self.store.get_from_remote_profile_cache(user_id) return profile or {} - async def get_displayname(self, target_user): + async def get_displayname(self, target_user: UserID) -> Optional[str]: if self.hs.is_mine(target_user): try: displayname = await self.store.get_profile_displayname( @@ -136,15 +159,19 @@ class BaseProfileHandler(BaseHandler): return result["displayname"] async def set_displayname( - self, target_user, requester, new_displayname, by_admin=False - ): + self, + target_user: UserID, + requester: Requester, + new_displayname: str, + by_admin: bool = False, + ) -> None: """Set the displayname of a user Args: - target_user (UserID): the user whose displayname is to be changed. - requester (Requester): The user attempting to make this change. - new_displayname (str): The displayname to give this user. - by_admin (bool): Whether this change was made by an administrator. + target_user: the user whose displayname is to be changed. + requester: The user attempting to make this change. + new_displayname: The displayname to give this user. + by_admin: Whether this change was made by an administrator. """ if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this homeserver") @@ -162,23 +189,30 @@ class BaseProfileHandler(BaseHandler): ) if not isinstance(new_displayname, str): - raise SynapseError(400, "Invalid displayname") + raise SynapseError( + 400, "'displayname' must be a string", errcode=Codes.INVALID_PARAM + ) if len(new_displayname) > MAX_DISPLAYNAME_LEN: raise SynapseError( 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) ) + displayname_to_set = new_displayname # type: Optional[str] if new_displayname == "": - new_displayname = None + displayname_to_set = None # If the admin changes the display name of a user, the requesting user cannot send # the join event to update the displayname in the rooms. # This must be done by the target user himself. if by_admin: - requester = create_requester(target_user) + requester = create_requester( + target_user, authenticated_entity=requester.authenticated_entity, + ) - await self.store.set_profile_displayname(target_user.localpart, new_displayname) + await self.store.set_profile_displayname( + target_user.localpart, displayname_to_set + ) if self.hs.config.user_directory_search_all_users: profile = await self.store.get_profileinfo(target_user.localpart) @@ -188,7 +222,7 @@ class BaseProfileHandler(BaseHandler): await self._update_join_states(requester, target_user) - async def get_avatar_url(self, target_user): + async def get_avatar_url(self, target_user: UserID) -> Optional[str]: if self.hs.is_mine(target_user): try: avatar_url = await self.store.get_profile_avatar_url( @@ -215,15 +249,19 @@ class BaseProfileHandler(BaseHandler): return result["avatar_url"] async def set_avatar_url( - self, target_user, requester, new_avatar_url, by_admin=False + self, + target_user: UserID, + requester: Requester, + new_avatar_url: str, + by_admin: bool = False, ): """Set a new avatar URL for a user. Args: - target_user (UserID): the user whose avatar URL is to be changed. - requester (Requester): The user attempting to make this change. - new_avatar_url (str): The avatar URL to give this user. - by_admin (bool): Whether this change was made by an administrator. + target_user: the user whose avatar URL is to be changed. + requester: The user attempting to make this change. 
+ new_avatar_url: The avatar URL to give this user. + by_admin: Whether this change was made by an administrator. """ if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this homeserver") @@ -239,7 +277,9 @@ class BaseProfileHandler(BaseHandler): ) if not isinstance(new_avatar_url, str): - raise SynapseError(400, "Invalid displayname") + raise SynapseError( + 400, "'avatar_url' must be a string", errcode=Codes.INVALID_PARAM + ) if len(new_avatar_url) > MAX_AVATAR_URL_LEN: raise SynapseError( @@ -248,7 +288,9 @@ class BaseProfileHandler(BaseHandler): # Same like set_displayname if by_admin: - requester = create_requester(target_user) + requester = create_requester( + target_user, authenticated_entity=requester.authenticated_entity + ) await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url) @@ -260,7 +302,7 @@ class BaseProfileHandler(BaseHandler): await self._update_join_states(requester, target_user) - async def on_profile_query(self, args): + async def on_profile_query(self, args: JsonDict) -> JsonDict: user = UserID.from_string(args["user_id"]) if not self.hs.is_mine(user): raise SynapseError(400, "User is not hosted on this homeserver") @@ -285,7 +327,9 @@ class BaseProfileHandler(BaseHandler): return response - async def _update_join_states(self, requester, target_user): + async def _update_join_states( + self, requester: Requester, target_user: UserID + ) -> None: if not self.hs.is_mine(target_user): return @@ -316,15 +360,17 @@ class BaseProfileHandler(BaseHandler): "Failed to update join event for room %s - %s", room_id, str(e) ) - async def check_profile_query_allowed(self, target_user, requester=None): + async def check_profile_query_allowed( + self, target_user: UserID, requester: Optional[UserID] = None + ) -> None: """Checks whether a profile query is allowed. If the 'require_auth_for_profile_requests' config flag is set to True and a 'requester' is provided, the query is only allowed if the two users share a room. Args: - target_user (UserID): The owner of the queried profile. - requester (None|UserID): The user querying for the profile. + target_user: The owner of the queried profile. + requester: The user querying for the profile. Raises: SynapseError(403): The two users share no room, or ne user couldn't @@ -363,25 +409,7 @@ class BaseProfileHandler(BaseHandler): raise SynapseError(403, "Profile isn't available", Codes.FORBIDDEN) raise - -class MasterProfileHandler(BaseProfileHandler): - PROFILE_UPDATE_MS = 60 * 1000 - PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 - - def __init__(self, hs): - super().__init__(hs) - - assert hs.config.worker_app is None - - self.clock.looping_call( - self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS - ) - - def _start_update_remote_profile_cache(self): - return run_as_background_process( - "Update remote profile", self._update_remote_profile_cache - ) - + @wrap_as_background_process("Update remote profile") async def _update_remote_profile_cache(self): """Called periodically to check profiles of remote users we haven't checked in a while. diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
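The profile diff above folds MasterProfileHandler back into a single ProfileHandler and relies on wrap_as_background_process together with the run_background_tasks flag to decide where the periodic cache refresh runs. A trimmed-down sketch of that wiring (class name and interval are illustrative; the decorator and config flag are the ones used above):

from synapse.metrics.background_process_metrics import wrap_as_background_process


class ExampleProfileHandler:
    UPDATE_MS = 60 * 1000

    def __init__(self, hs):
        self.clock = hs.get_clock()
        if hs.config.run_background_tasks:
            # Only the worker responsible for background tasks schedules the job.
            self.clock.looping_call(self._update_remote_profile_cache, self.UPDATE_MS)

    @wrap_as_background_process("Update remote profile")
    async def _update_remote_profile_cache(self):
        ...  # refresh stale remote profile cache entries here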
index c32f314a1c..a7550806e6 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py
@@ -14,23 +14,29 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.util.async_helpers import Linearizer from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class ReadMarkerHandler(BaseHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.server_name = hs.config.server_name self.store = hs.get_datastore() self.read_marker_linearizer = Linearizer(name="read_marker") self.notifier = hs.get_notifier() - async def received_client_read_marker(self, room_id, user_id, event_id): + async def received_client_read_marker( + self, room_id: str, user_id: str, event_id: str + ) -> None: """Updates the read marker for a given user in a given room if the event ID given is ahead in the stream relative to the current read marker. diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 7225923757..a9abdf42e0 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py
@@ -13,16 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import TYPE_CHECKING, List, Optional, Tuple +from synapse.appservice import ApplicationService from synapse.handlers._base import BaseHandler -from synapse.types import ReadReceipt, get_domain_from_id -from synapse.util.async_helpers import maybe_awaitable +from synapse.types import JsonDict, ReadReceipt, get_domain_from_id + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) class ReceiptsHandler(BaseHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.server_name = hs.config.server_name @@ -35,7 +39,7 @@ class ReceiptsHandler(BaseHandler): self.clock = self.hs.get_clock() self.state = hs.get_state_handler() - async def _received_remote_receipt(self, origin, content): + async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None: """Called when we receive an EDU of type m.receipt from a remote HS. """ receipts = [] @@ -62,11 +66,11 @@ class ReceiptsHandler(BaseHandler): await self._handle_new_receipts(receipts) - async def _handle_new_receipts(self, receipts): + async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool: """Takes a list of receipts, stores them and informs the notifier. """ - min_batch_id = None - max_batch_id = None + min_batch_id = None # type: Optional[int] + max_batch_id = None # type: Optional[int] for receipt in receipts: res = await self.store.insert_receipt( @@ -88,7 +92,8 @@ class ReceiptsHandler(BaseHandler): if max_batch_id is None or max_persisted_id > max_batch_id: max_batch_id = max_persisted_id - if min_batch_id is None: + # Either both of these should be None or neither. + if min_batch_id is None or max_batch_id is None: # no new receipts return False @@ -96,15 +101,15 @@ class ReceiptsHandler(BaseHandler): self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids) # Note that the min here shouldn't be relied upon to be accurate. - await maybe_awaitable( - self.hs.get_pusherpool().on_new_receipts( - min_batch_id, max_batch_id, affected_room_ids - ) + await self.hs.get_pusherpool().on_new_receipts( + min_batch_id, max_batch_id, affected_room_ids ) return True - async def received_client_receipt(self, room_id, receipt_type, user_id, event_id): + async def received_client_receipt( + self, room_id: str, receipt_type: str, user_id: str, event_id: str + ) -> None: """Called when a client tells us a local user has read up to the given event_id in the room. """ @@ -124,10 +129,12 @@ class ReceiptsHandler(BaseHandler): class ReceiptEventSource: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() - async def get_new_events(self, from_key, room_ids, **kwargs): + async def get_new_events( + self, from_key: int, room_ids: List[str], **kwargs + ) -> Tuple[List[JsonDict], int]: from_key = int(from_key) to_key = self.get_current_key() @@ -140,5 +147,37 @@ class ReceiptEventSource: return (events, to_key) - def get_current_key(self, direction="f"): + async def get_new_events_as( + self, from_key: int, service: ApplicationService + ) -> Tuple[List[JsonDict], int]: + """Returns a set of new receipt events that an appservice + may be interested in. 
+ + Args: + from_key: the stream position at which events should be fetched from + service: The appservice which may be interested + """ + from_key = int(from_key) + to_key = self.get_current_key() + + if from_key == to_key: + return [], to_key + + # Fetch all read receipts for all rooms, up to a limit of 100. This is ordered + # by most recent. + rooms_to_events = await self.store.get_linearized_receipts_for_all_rooms( + from_key=from_key, to_key=to_key + ) + + # Then filter down to rooms that the AS can read + events = [] + for room_id, event in rooms_to_events.items(): + if not await service.matches_user_in_member_list(room_id, self.store): + continue + + events.append(event) + + return (events, to_key) + + def get_current_key(self, direction: str = "f") -> int: return self.store.get_max_receipt_stream_id() diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 538f4b2a61..a2cf0f6f3e 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
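register.py below reworks the autogenerated-user-ID loop so it no longer resets `user`/`user_id` on collision, and instead breaks on success or gives up after ten failures. A simplified sketch of that retry shape, with a hypothetical `next_candidate` generator and `is_taken` predicate standing in for the sequential ID generator and datastore check:

    from typing import Callable

    def pick_localpart(next_candidate: Callable[[], str],
                       is_taken: Callable[[str], bool]) -> str:
        # Breaks on the first free candidate, or errors out after 10 failed
        # attempts, mirroring the loop in RegistrationHandler.register_user.
        fail_count = 0
        while True:
            if fail_count > 10:
                raise RuntimeError("Unable to find a suitable guest user ID")
            localpart = next_candidate()
            if not is_taken(localpart):
                return localpart
            fail_count += 1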
@@ -15,10 +15,12 @@ """Contains functions for registering clients.""" import logging +from typing import TYPE_CHECKING, List, Optional, Tuple from synapse import types from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError +from synapse.appservice import ApplicationService from synapse.config.server import is_threepid_reserved from synapse.http.servlet import assert_params_in_dict from synapse.replication.http.login import RegisterDeviceReplicationServlet @@ -32,26 +34,25 @@ from synapse.types import RoomAlias, UserID, create_requester from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class RegistrationHandler(BaseHandler): - def __init__(self, hs): - """ - - Args: - hs (synapse.server.HomeServer): - """ + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self.profile_handler = hs.get_profile_handler() self.user_directory_handler = hs.get_user_directory_handler() - self.identity_handler = self.hs.get_handlers().identity_handler + self.identity_handler = self.hs.get_identity_handler() self.ratelimiter = hs.get_registration_ratelimiter() self.macaroon_gen = hs.get_macaroon_generator() self._server_notices_mxid = hs.config.server_notices_mxid + self._server_name = hs.hostname self.spam_checker = hs.get_spam_checker() @@ -70,7 +71,10 @@ class RegistrationHandler(BaseHandler): self.session_lifetime = hs.config.session_lifetime async def check_username( - self, localpart, guest_access_token=None, assigned_user_id=None + self, + localpart: str, + guest_access_token: Optional[str] = None, + assigned_user_id: Optional[str] = None, ): if types.contains_invalid_mxid_characters(localpart): raise SynapseError( @@ -115,7 +119,10 @@ class RegistrationHandler(BaseHandler): 400, "User ID already taken.", errcode=Codes.USER_IN_USE ) user_data = await self.auth.get_user_by_access_token(guest_access_token) - if not user_data["is_guest"] or user_data["user"].localpart != localpart: + if ( + not user_data.is_guest + or UserID.from_string(user_data.user_id).localpart != localpart + ): raise AuthError( 403, "Cannot register taken user ID without valid guest " @@ -136,45 +143,51 @@ class RegistrationHandler(BaseHandler): async def register_user( self, - localpart=None, - password_hash=None, - guest_access_token=None, - make_guest=False, - admin=False, - threepid=None, - user_type=None, - default_display_name=None, - address=None, - bind_emails=[], - by_admin=False, - user_agent_ips=None, - ): + localpart: Optional[str] = None, + password_hash: Optional[str] = None, + guest_access_token: Optional[str] = None, + make_guest: bool = False, + admin: bool = False, + threepid: Optional[dict] = None, + user_type: Optional[str] = None, + default_display_name: Optional[str] = None, + address: Optional[str] = None, + bind_emails: List[str] = [], + by_admin: bool = False, + user_agent_ips: Optional[List[Tuple[str, str]]] = None, + ) -> str: """Registers a new client on the server. Args: localpart: The local part of the user ID to register. If None, one will be generated. - password_hash (str|None): The hashed password to assign to this user so they can + password_hash: The hashed password to assign to this user so they can login again. This can be None which means they cannot login again via a password (e.g. 
the user is an application service user). - user_type (str|None): type of user. One of the values from + guest_access_token: The access token used when this was a guest + account. + make_guest: True if the the new user should be guest, + false to add a regular user account. + admin: True if the user should be registered as a server admin. + threepid: The threepid used for registering, if any. + user_type: type of user. One of the values from api.constants.UserTypes, or None for a normal user. - default_display_name (unicode|None): if set, the new user's displayname + default_display_name: if set, the new user's displayname will be set to this. Defaults to 'localpart'. - address (str|None): the IP address used to perform the registration. - bind_emails (List[str]): list of emails to bind to this account. - by_admin (bool): True if this registration is being made via the + address: the IP address used to perform the registration. + bind_emails: list of emails to bind to this account. + by_admin: True if this registration is being made via the admin api, otherwise False. - user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used + user_agent_ips: Tuples of IP addresses and user-agents used during the registration process. Returns: - str: user_id + The registere user_id. Raises: SynapseError if there was a problem registering. """ self.check_registration_ratelimit(address) - result = self.spam_checker.check_registration_for_spam( + result = await self.spam_checker.check_registration_for_spam( threepid, localpart, user_agent_ips or [], ) @@ -232,8 +245,10 @@ class RegistrationHandler(BaseHandler): else: # autogen a sequential user ID fail_count = 0 - user = None - while not user: + # If a default display name is not given, generate one. + generate_display_name = default_display_name is None + # This breaks on successful registration *or* errors after 10 failures. + while True: # Fail after being unable to find a suitable ID a few times if fail_count > 10: raise SynapseError(500, "Unable to find a suitable guest user ID") @@ -242,7 +257,7 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() self.check_user_id_not_appservice_exclusive(user_id) - if default_display_name is None: + if generate_display_name: default_display_name = localpart try: await self.register_with_store( @@ -258,8 +273,6 @@ class RegistrationHandler(BaseHandler): break except SynapseError: # if user id is taken, just generate another - user = None - user_id = None fail_count += 1 if not self.hs.config.user_consent_at_registration: @@ -291,7 +304,7 @@ class RegistrationHandler(BaseHandler): return user_id - async def _create_and_join_rooms(self, user_id: str): + async def _create_and_join_rooms(self, user_id: str) -> None: """ Create the auto-join rooms and join or invite the user to them. @@ -314,7 +327,8 @@ class RegistrationHandler(BaseHandler): requires_join = False if self.hs.config.registration.auto_join_user_id: fake_requester = create_requester( - self.hs.config.registration.auto_join_user_id + self.hs.config.registration.auto_join_user_id, + authenticated_entity=self._server_name, ) # If the room requires an invite, add the user to the list of invites. @@ -326,7 +340,9 @@ class RegistrationHandler(BaseHandler): # being necessary this will occur after the invite was sent. 
requires_join = True else: - fake_requester = create_requester(user_id) + fake_requester = create_requester( + user_id, authenticated_entity=self._server_name + ) # Choose whether to federate the new room. if not self.hs.config.registration.autocreate_auto_join_rooms_federated: @@ -359,7 +375,9 @@ class RegistrationHandler(BaseHandler): # created it, then ensure the first user joins it. if requires_join: await room_member_handler.update_membership( - requester=create_requester(user_id), + requester=create_requester( + user_id, authenticated_entity=self._server_name + ), target=UserID.from_string(user_id), room_id=info["room_id"], # Since it was just created, there are no remote hosts. @@ -367,15 +385,10 @@ class RegistrationHandler(BaseHandler): action="join", ratelimit=False, ) - - except ConsentNotGivenError as e: - # Technically not necessary to pull out this error though - # moving away from bare excepts is a good thing to do. - logger.error("Failed to join new user to %r: %r", r, e) except Exception as e: logger.error("Failed to join new user to %r: %r", r, e) - async def _join_rooms(self, user_id: str): + async def _join_rooms(self, user_id: str) -> None: """ Join or invite the user to the auto-join rooms. @@ -421,9 +434,13 @@ class RegistrationHandler(BaseHandler): # Send the invite, if necessary. if requires_invite: + # If an invite is required, there must be a auto-join user ID. + assert self.hs.config.registration.auto_join_user_id + await room_member_handler.update_membership( requester=create_requester( - self.hs.config.registration.auto_join_user_id + self.hs.config.registration.auto_join_user_id, + authenticated_entity=self._server_name, ), target=UserID.from_string(user_id), room_id=room_id, @@ -434,7 +451,9 @@ class RegistrationHandler(BaseHandler): # Send the join. await room_member_handler.update_membership( - requester=create_requester(user_id), + requester=create_requester( + user_id, authenticated_entity=self._server_name + ), target=UserID.from_string(user_id), room_id=room_id, remote_room_hosts=remote_room_hosts, @@ -449,7 +468,7 @@ class RegistrationHandler(BaseHandler): except Exception as e: logger.error("Failed to join new user to %r: %r", r, e) - async def _auto_join_rooms(self, user_id: str): + async def _auto_join_rooms(self, user_id: str) -> None: """Automatically joins users to auto join rooms - creating the room in the first place if the user is the first to be created. 
@@ -472,16 +491,16 @@ class RegistrationHandler(BaseHandler): else: await self._join_rooms(user_id) - async def post_consent_actions(self, user_id): + async def post_consent_actions(self, user_id: str) -> None: """A series of registration actions that can only be carried out once consent has been granted Args: - user_id (str): The user to join + user_id: The user to join """ await self._auto_join_rooms(user_id) - async def appservice_register(self, user_localpart, as_token): + async def appservice_register(self, user_localpart: str, as_token: str) -> str: user = UserID(user_localpart, self.hs.hostname) user_id = user.to_string() service = self.store.get_app_service_by_token(as_token) @@ -506,7 +525,9 @@ class RegistrationHandler(BaseHandler): ) return user_id - def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None): + def check_user_id_not_appservice_exclusive( + self, user_id: str, allowed_appservice: Optional[ApplicationService] = None + ) -> None: # don't allow people to register the server notices mxid if self._server_notices_mxid is not None: if user_id == self._server_notices_mxid: @@ -530,12 +551,12 @@ class RegistrationHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) - def check_registration_ratelimit(self, address): + def check_registration_ratelimit(self, address: Optional[str]) -> None: """A simple helper method to check whether the registration rate limit has been hit for a given IP address Args: - address (str|None): the IP address used to perform the registration. If this is + address: the IP address used to perform the registration. If this is None, no ratelimiting will be performed. Raises: @@ -546,42 +567,39 @@ class RegistrationHandler(BaseHandler): self.ratelimiter.ratelimit(address) - def register_with_store( + async def register_with_store( self, - user_id, - password_hash=None, - was_guest=False, - make_guest=False, - appservice_id=None, - create_profile_with_displayname=None, - admin=False, - user_type=None, - address=None, - shadow_banned=False, - ): + user_id: str, + password_hash: Optional[str] = None, + was_guest: bool = False, + make_guest: bool = False, + appservice_id: Optional[str] = None, + create_profile_with_displayname: Optional[str] = None, + admin: bool = False, + user_type: Optional[str] = None, + address: Optional[str] = None, + shadow_banned: bool = False, + ) -> None: """Register user in the datastore. Args: - user_id (str): The desired user ID to register. - password_hash (str|None): Optional. The password hash for this user. - was_guest (bool): Optional. Whether this is a guest account being + user_id: The desired user ID to register. + password_hash: Optional. The password hash for this user. + was_guest: Optional. Whether this is a guest account being upgraded to a non-guest account. - make_guest (boolean): True if the the new user should be guest, + make_guest: True if the the new user should be guest, false to add a regular user account. - appservice_id (str|None): The ID of the appservice registering the user. - create_profile_with_displayname (unicode|None): Optionally create a + appservice_id: The ID of the appservice registering the user. + create_profile_with_displayname: Optionally create a profile for the user, setting their displayname to the given value - admin (boolean): is an admin user? - user_type (str|None): type of user. One of the values from + admin: is an admin user? + user_type: type of user. One of the values from api.constants.UserTypes, or None for a normal user. 
- address (str|None): the IP address used to perform the registration. - shadow_banned (bool): Whether to shadow-ban the user - - Returns: - Awaitable + address: the IP address used to perform the registration. + shadow_banned: Whether to shadow-ban the user """ if self.hs.config.worker_app: - return self._register_client( + await self._register_client( user_id=user_id, password_hash=password_hash, was_guest=was_guest, @@ -594,7 +612,7 @@ class RegistrationHandler(BaseHandler): shadow_banned=shadow_banned, ) else: - return self.store.register_user( + await self.store.register_user( user_id=user_id, password_hash=password_hash, was_guest=was_guest, @@ -607,22 +625,25 @@ class RegistrationHandler(BaseHandler): ) async def register_device( - self, user_id, device_id, initial_display_name, is_guest=False - ): + self, + user_id: str, + device_id: Optional[str], + initial_display_name: Optional[str], + is_guest: bool = False, + is_appservice_ghost: bool = False, + ) -> Tuple[str, str]: """Register a device for a user and generate an access token. The access token will be limited by the homeserver's session_lifetime config. Args: - user_id (str): full canonical @user:id - device_id (str|None): The device ID to check, or None to generate - a new one. - initial_display_name (str|None): An optional display name for the - device. - is_guest (bool): Whether this is a guest account + user_id: full canonical @user:id + device_id: The device ID to check, or None to generate a new one. + initial_display_name: An optional display name for the device. + is_guest: Whether this is a guest account Returns: - tuple[str, str]: Tuple of device ID and access token + Tuple of device ID and access token """ if self.hs.config.worker_app: @@ -631,6 +652,7 @@ class RegistrationHandler(BaseHandler): device_id=device_id, initial_display_name=initial_display_name, is_guest=is_guest, + is_appservice_ghost=is_appservice_ghost, ) return r["device_id"], r["access_token"] @@ -642,7 +664,7 @@ class RegistrationHandler(BaseHandler): ) valid_until_ms = self.clock.time_msec() + self.session_lifetime - device_id = await self.device_handler.check_device_registered( + registered_device_id = await self.device_handler.check_device_registered( user_id, device_id, initial_display_name ) if is_guest: @@ -652,20 +674,24 @@ class RegistrationHandler(BaseHandler): ) else: access_token = await self._auth_handler.get_access_token_for_user_id( - user_id, device_id=device_id, valid_until_ms=valid_until_ms + user_id, + device_id=registered_device_id, + valid_until_ms=valid_until_ms, + is_appservice_ghost=is_appservice_ghost, ) - return (device_id, access_token) + return (registered_device_id, access_token) - async def post_registration_actions(self, user_id, auth_result, access_token): + async def post_registration_actions( + self, user_id: str, auth_result: dict, access_token: Optional[str] + ) -> None: """A user has completed registration Args: - user_id (str): The user ID that consented - auth_result (dict): The authenticated credentials of the newly - registered user. - access_token (str|None): The access token of the newly logged in - device, or None if `inhibit_login` enabled. + user_id: The user ID that consented + auth_result: The authenticated credentials of the newly registered user. + access_token: The access token of the newly logged in device, or + None if `inhibit_login` enabled. 
""" if self.hs.config.worker_app: await self._post_registration_client( @@ -691,19 +717,20 @@ class RegistrationHandler(BaseHandler): if auth_result and LoginType.TERMS in auth_result: await self._on_user_consented(user_id, self.hs.config.user_consent_version) - async def _on_user_consented(self, user_id, consent_version): + async def _on_user_consented(self, user_id: str, consent_version: str) -> None: """A user consented to the terms on registration Args: - user_id (str): The user ID that consented. - consent_version (str): version of the policy the user has - consented to. + user_id: The user ID that consented. + consent_version: version of the policy the user has consented to. """ logger.info("%s has consented to the privacy policy", user_id) await self.store.user_set_consent_version(user_id, consent_version) await self.post_consent_actions(user_id) - async def _register_email_threepid(self, user_id, threepid, token): + async def _register_email_threepid( + self, user_id: str, threepid: dict, token: Optional[str] + ) -> None: """Add an email address as a 3pid identifier Also adds an email pusher for the email address, if configured in the @@ -712,10 +739,9 @@ class RegistrationHandler(BaseHandler): Must be called on master. Args: - user_id (str): id of user - threepid (object): m.login.email.identity auth response - token (str|None): access_token for the user, or None if not logged - in. + user_id: id of user + threepid: m.login.email.identity auth response + token: access_token for the user, or None if not logged in. """ reqd = ("medium", "address", "validated_at") if any(x not in threepid for x in reqd): @@ -741,7 +767,9 @@ class RegistrationHandler(BaseHandler): # up when the access token is saved, but that's quite an # invasive change I'd rather do separately. user_tuple = await self.store.get_user_by_access_token(token) - token_id = user_tuple["token_id"] + # The token better still exist. + assert user_tuple + token_id = user_tuple.token_id await self.pusher_pool.add_pusher( user_id=user_id, @@ -755,14 +783,14 @@ class RegistrationHandler(BaseHandler): data={}, ) - async def _register_msisdn_threepid(self, user_id, threepid): + async def _register_msisdn_threepid(self, user_id: str, threepid: dict) -> None: """Add a phone number as a 3pid identifier Must be called on master. Args: - user_id (str): id of user - threepid (object): m.login.msisdn auth response + user_id: id of user + threepid: m.login.msisdn auth response """ try: assert_params_in_dict(threepid, ["medium", "address", "validated_at"]) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index d5f7c78edf..1f809fa161 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
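The room.py hunks replace bare "shared"/"world_readable" string literals in the room-creation presets with the new `HistoryVisibility` constants. A sketch of what such a constants class looks like, assuming values that mirror the Matrix history-visibility settings (the real definitions live in synapse.api.constants):

    class HistoryVisibility:
        # Assumed to mirror synapse.api.constants.HistoryVisibility.
        INVITED = "invited"
        JOINED = "joined"
        SHARED = "shared"
        WORLD_READABLE = "world_readable"

    PRIVATE_CHAT_PRESET = {
        "join_rules": "invite",
        # Named constant instead of the bare "shared" literal.
        "history_visibility": HistoryVisibility.SHARED,
    }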
@@ -27,6 +27,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple from synapse.api.constants import ( EventTypes, + HistoryVisibility, JoinRules, Membership, RoomCreationPreset, @@ -81,21 +82,21 @@ class RoomCreationHandler(BaseHandler): self._presets_dict = { RoomCreationPreset.PRIVATE_CHAT: { "join_rules": JoinRules.INVITE, - "history_visibility": "shared", + "history_visibility": HistoryVisibility.SHARED, "original_invitees_have_ops": False, "guest_can_join": True, "power_level_content_override": {"invite": 0}, }, RoomCreationPreset.TRUSTED_PRIVATE_CHAT: { "join_rules": JoinRules.INVITE, - "history_visibility": "shared", + "history_visibility": HistoryVisibility.SHARED, "original_invitees_have_ops": True, "guest_can_join": True, "power_level_content_override": {"invite": 0}, }, RoomCreationPreset.PUBLIC_CHAT: { "join_rules": JoinRules.PUBLIC, - "history_visibility": "shared", + "history_visibility": HistoryVisibility.SHARED, "original_invitees_have_ops": False, "guest_can_join": False, "power_level_content_override": {}, @@ -120,7 +121,7 @@ class RoomCreationHandler(BaseHandler): # subsequent requests self._upgrade_response_cache = ResponseCache( hs, "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS - ) + ) # type: ResponseCache[Tuple[str, str]] self._server_notices_mxid = hs.config.server_notices_mxid self.third_party_event_rules = hs.get_third_party_event_rules() @@ -185,6 +186,7 @@ class RoomCreationHandler(BaseHandler): ShadowBanError if the requester is shadow-banned. """ user_id = requester.user.to_string() + assert self.hs.is_mine_id(user_id), "User must be our own: %s" % (user_id,) # start by allocating a new room id r = await self.store.get_room(old_room_id) @@ -213,7 +215,6 @@ class RoomCreationHandler(BaseHandler): "replacement_room": new_room_id, }, }, - token_id=requester.access_token_id, ) old_room_version = await self.store.get_room_version_id(old_room_id) await self.auth.check_from_context( @@ -229,8 +230,8 @@ class RoomCreationHandler(BaseHandler): ) # now send the tombstone - await self.event_creation_handler.send_nonmember_event( - requester, tombstone_event, tombstone_context + await self.event_creation_handler.handle_new_client_event( + requester=requester, event=tombstone_event, context=tombstone_context, ) old_room_state = await tombstone_context.get_current_state_ids() @@ -358,7 +359,7 @@ class RoomCreationHandler(BaseHandler): """ user_id = requester.user.to_string() - if not self.spam_checker.user_may_create_room(user_id): + if not await self.spam_checker.user_may_create_room(user_id): raise SynapseError(403, "You are not permitted to create rooms") creation_content = { @@ -440,6 +441,7 @@ class RoomCreationHandler(BaseHandler): invite_list=[], initial_state=initial_state, creation_content=creation_content, + ratelimit=False, ) # Transfer membership events @@ -587,7 +589,7 @@ class RoomCreationHandler(BaseHandler): """ user_id = requester.user.to_string() - await self.auth.check_auth_blocking(user_id) + await self.auth.check_auth_blocking(requester=requester) if ( self._server_notices_mxid is not None @@ -608,7 +610,7 @@ class RoomCreationHandler(BaseHandler): 403, "You are not permitted to create rooms", Codes.FORBIDDEN ) - if not is_requester_admin and not self.spam_checker.user_may_create_room( + if not is_requester_admin and not await self.spam_checker.user_may_create_room( user_id ): raise SynapseError(403, "You are not permitted to create rooms") @@ -681,7 +683,16 @@ class RoomCreationHandler(BaseHandler): creator_id=user_id, 
is_public=is_public, room_version=room_version, ) - directory_handler = self.hs.get_handlers().directory_handler + # Check whether this visibility value is blocked by a third party module + allowed_by_third_party_rules = await ( + self.third_party_event_rules.check_visibility_can_be_modified( + room_id, visibility + ) + ) + if not allowed_by_third_party_rules: + raise SynapseError(403, "Room visibility value not allowed.") + + directory_handler = self.hs.get_directory_handler() if room_alias: await directory_handler.create_association( requester=requester, @@ -726,6 +737,7 @@ class RoomCreationHandler(BaseHandler): room_alias=room_alias, power_level_content_override=power_level_content_override, creator_join_profile=creator_join_profile, + ratelimit=ratelimit, ) if "name" in config: @@ -762,22 +774,29 @@ class RoomCreationHandler(BaseHandler): ratelimit=False, ) - for invitee in invite_list: + # we avoid dropping the lock between invites, as otherwise joins can + # start coming in and making the createRoom slow. + # + # we also don't need to check the requester's shadow-ban here, as we + # have already done so above (and potentially emptied invite_list). + with (await self.room_member_handler.member_linearizer.queue((room_id,))): content = {} is_direct = config.get("is_direct", None) if is_direct: content["is_direct"] = is_direct - # Note that update_membership with an action of "invite" can raise a - # ShadowBanError, but this was handled above by emptying invite_list. - _, last_stream_id = await self.room_member_handler.update_membership( - requester, - UserID.from_string(invitee), - room_id, - "invite", - ratelimit=False, - content=content, - ) + for invitee in invite_list: + ( + _, + last_stream_id, + ) = await self.room_member_handler.update_membership_locked( + requester, + UserID.from_string(invitee), + room_id, + "invite", + ratelimit=False, + content=content, + ) for invite_3pid in invite_3pid_list: id_server = invite_3pid["id_server"] @@ -822,6 +841,7 @@ class RoomCreationHandler(BaseHandler): room_alias: Optional[RoomAlias] = None, power_level_content_override: Optional[JsonDict] = None, creator_join_profile: Optional[JsonDict] = None, + ratelimit: bool = True, ) -> int: """Sends the initial events into a new room. @@ -868,7 +888,7 @@ class RoomCreationHandler(BaseHandler): creator.user, room_id, "join", - ratelimit=False, + ratelimit=ratelimit, content=creator_join_profile, ) @@ -962,8 +982,6 @@ class RoomCreationHandler(BaseHandler): try: random_string = stringutils.random_string(18) gen_room_id = RoomID(random_string, self.hs.hostname).to_string() - if isinstance(gen_room_id, bytes): - gen_room_id = gen_room_id.decode("utf-8") await self.store.store_room( room_id=gen_room_id, room_creator_user_id=creator_id, @@ -1243,7 +1261,9 @@ class RoomShutdownHandler: 400, "User must be our own: %s" % (new_room_user_id,) ) - room_creator_requester = create_requester(new_room_user_id) + room_creator_requester = create_requester( + new_room_user_id, authenticated_entity=requester_user_id + ) info, stream_id = await self._room_creation_handler.create_room( room_creator_requester, @@ -1261,7 +1281,7 @@ class RoomShutdownHandler: ) # We now wait for the create room to come back in via replication so - # that we can assume that all the joins/invites have propogated before + # that we can assume that all the joins/invites have propagated before # we try and auto join below. 
await self._replication.wait_for_stream_position( self.hs.config.worker.events_shard_config.get_instance(new_room_id), @@ -1283,7 +1303,9 @@ class RoomShutdownHandler: try: # Kick users from room - target_requester = create_requester(user_id) + target_requester = create_requester( + user_id, authenticated_entity=requester_user_id + ) _, stream_id = await self.room_member_handler.update_membership( requester=target_requester, target=target_requester.user, diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 4a13c8e912..14f14db449 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
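room_list.py below adds return-type annotations to `RoomListNextBatch.from_token`/`to_token`, which round-trip the pagination state through msgpack and unpadded base64. A minimal standalone version of that round trip; the plain-dict `batch` is illustrative, while the real class is a namedtuple with shortened key aliases:

    import msgpack
    from unpaddedbase64 import decode_base64, encode_base64

    def to_token(batch: dict) -> str:
        # Serialise the batch state into a compact, opaque, URL-safe token.
        return encode_base64(msgpack.dumps(batch))

    def from_token(token: str) -> dict:
        return msgpack.loads(decode_base64(token), raw=False)

    token = to_token({"last_joined_members": 42, "last_room_id": "!abc:example.com"})
    assert from_token(token)["last_joined_members"] == 42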
@@ -15,19 +15,22 @@ import logging from collections import namedtuple -from typing import Any, Dict, Optional +from typing import TYPE_CHECKING, Optional, Tuple import msgpack from unpaddedbase64 import decode_base64, encode_base64 -from synapse.api.constants import EventTypes, JoinRules +from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules from synapse.api.errors import Codes, HttpResponseException -from synapse.types import ThirdPartyInstanceID +from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util.caches.descriptors import cached from synapse.util.caches.response_cache import ResponseCache from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000 @@ -37,37 +40,38 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) class RoomListHandler(BaseHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.enable_room_list_search = hs.config.enable_room_list_search - self.response_cache = ResponseCache(hs, "room_list") + self.response_cache = ResponseCache( + hs, "room_list" + ) # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]] self.remote_response_cache = ResponseCache( hs, "remote_room_list", timeout_ms=30 * 1000 - ) + ) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]] async def get_local_public_room_list( self, - limit=None, - since_token=None, - search_filter=None, - network_tuple=EMPTY_THIRD_PARTY_ID, - from_federation=False, - ): + limit: Optional[int] = None, + since_token: Optional[str] = None, + search_filter: Optional[dict] = None, + network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID, + from_federation: bool = False, + ) -> JsonDict: """Generate a local public room list. There are multiple different lists: the main one plus one per third party network. A client can ask for a specific list or to return all. Args: - limit (int|None) - since_token (str|None) - search_filter (dict|None) - network_tuple (ThirdPartyInstanceID): Which public list to use. + limit + since_token + search_filter + network_tuple: Which public list to use. This can be (None, None) to indicate the main list, or a particular appservice and network id to use an appservice specific one. Setting to None returns all public rooms across all lists. - from_federation (bool): true iff the request comes from the federation - API + from_federation: true iff the request comes from the federation API """ if not self.enable_room_list_search: return {"chunk": [], "total_room_count_estimate": 0} @@ -107,10 +111,10 @@ class RoomListHandler(BaseHandler): self, limit: Optional[int] = None, since_token: Optional[str] = None, - search_filter: Optional[Dict] = None, + search_filter: Optional[dict] = None, network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID, from_federation: bool = False, - ) -> Dict[str, Any]: + ) -> JsonDict: """Generate a public room list. Args: limit: Maximum amount of rooms to return. 
@@ -131,13 +135,17 @@ class RoomListHandler(BaseHandler): if since_token: batch_token = RoomListNextBatch.from_token(since_token) - bounds = (batch_token.last_joined_members, batch_token.last_room_id) + bounds = ( + batch_token.last_joined_members, + batch_token.last_room_id, + ) # type: Optional[Tuple[int, str]] forwards = batch_token.direction_is_forward + has_batch_token = True else: - batch_token = None bounds = None forwards = True + has_batch_token = False # we request one more than wanted to see if there are more pages to come probing_limit = limit + 1 if limit is not None else None @@ -159,7 +167,8 @@ class RoomListHandler(BaseHandler): "canonical_alias": room["canonical_alias"], "num_joined_members": room["joined_members"], "avatar_url": room["avatar"], - "world_readable": room["history_visibility"] == "world_readable", + "world_readable": room["history_visibility"] + == HistoryVisibility.WORLD_READABLE, "guest_can_join": room["guest_access"] == "can_join", } @@ -168,7 +177,7 @@ class RoomListHandler(BaseHandler): results = [build_room_entry(r) for r in results] - response = {} + response = {} # type: JsonDict num_results = len(results) if limit is not None: more_to_come = num_results == probing_limit @@ -186,7 +195,7 @@ class RoomListHandler(BaseHandler): initial_entry = results[0] if forwards: - if batch_token: + if has_batch_token: # If there was a token given then we assume that there # must be previous results. response["prev_batch"] = RoomListNextBatch( @@ -202,7 +211,7 @@ class RoomListHandler(BaseHandler): direction_is_forward=True, ).to_token() else: - if batch_token: + if has_batch_token: response["next_batch"] = RoomListNextBatch( last_joined_members=final_entry["num_joined_members"], last_room_id=final_entry["room_id"], @@ -292,7 +301,7 @@ class RoomListHandler(BaseHandler): return None # Return whether this room is open to federation users or not - create_event = current_state.get((EventTypes.Create, "")) + create_event = current_state[EventTypes.Create, ""] result["m.federate"] = create_event.content.get("m.federate", True) name_event = current_state.get((EventTypes.Name, "")) @@ -317,7 +326,7 @@ class RoomListHandler(BaseHandler): visibility = None if visibility_event: visibility = visibility_event.content.get("history_visibility", None) - result["world_readable"] = visibility == "world_readable" + result["world_readable"] = visibility == HistoryVisibility.WORLD_READABLE guest_event = current_state.get((EventTypes.GuestAccess, "")) guest = None @@ -335,13 +344,13 @@ class RoomListHandler(BaseHandler): async def get_remote_public_room_list( self, - server_name, - limit=None, - since_token=None, - search_filter=None, - include_all_networks=False, - third_party_instance_id=None, - ): + server_name: str, + limit: Optional[int] = None, + since_token: Optional[str] = None, + search_filter: Optional[dict] = None, + include_all_networks: bool = False, + third_party_instance_id: Optional[str] = None, + ) -> JsonDict: if not self.enable_room_list_search: return {"chunk": [], "total_room_count_estimate": 0} @@ -398,13 +407,13 @@ class RoomListHandler(BaseHandler): async def _get_remote_list_cached( self, - server_name, - limit=None, - since_token=None, - search_filter=None, - include_all_networks=False, - third_party_instance_id=None, - ): + server_name: str, + limit: Optional[int] = None, + since_token: Optional[str] = None, + search_filter: Optional[dict] = None, + include_all_networks: bool = False, + third_party_instance_id: Optional[str] = None, + ) -> JsonDict: 
repl_layer = self.hs.get_federation_client() if search_filter: # We can't cache when asking for search @@ -455,24 +464,24 @@ class RoomListNextBatch( REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()} @classmethod - def from_token(cls, token): + def from_token(cls, token: str) -> "RoomListNextBatch": decoded = msgpack.loads(decode_base64(token), raw=False) return RoomListNextBatch( **{cls.REVERSE_KEY_DICT[key]: val for key, val in decoded.items()} ) - def to_token(self): + def to_token(self) -> str: return encode_base64( msgpack.dumps( {self.KEY_DICT[key]: val for key, val in self._asdict().items()} ) ) - def copy_and_replace(self, **kwds): + def copy_and_replace(self, **kwds) -> "RoomListNextBatch": return self._replace(**kwds) -def _matches_room_entry(room_entry, search_filter): +def _matches_room_entry(room_entry: JsonDict, search_filter: dict) -> bool: if search_filter and search_filter.get("generic_search_term", None): generic_search_term = search_filter["generic_search_term"].upper() if generic_search_term in room_entry.get("name", "").upper(): diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 8feba8c90a..cb5a29bc7e 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
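The room_member.py changes add an up-front transaction-ID check so a retried membership request returns the already-persisted event instead of creating a duplicate. The same check, pulled out into a hypothetical helper for illustration:

    async def get_existing_membership_event(store, room_id, requester, txn_id):
        """Return (event_id, stream_ordering) if this (access token, txn_id)
        pair already produced an event in the room, else None."""
        if txn_id and requester.access_token_id:
            existing_event_id = await store.get_event_id_from_transaction_id(
                room_id, requester.user.to_string(), requester.access_token_id, txn_id,
            )
            if existing_event_id:
                event_pos = await store.get_position_for_event(existing_event_id)
                return existing_event_id, event_pos.stream
        return None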
@@ -17,12 +17,10 @@ import abc import logging import random from http import HTTPStatus -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union - -from unpaddedbase64 import encode_base64 +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple from synapse import types -from synapse.api.constants import MAX_DEPTH, EventTypes, Membership +from synapse.api.constants import AccountDataTypes, EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, @@ -31,13 +29,8 @@ from synapse.api.errors import ( SynapseError, ) from synapse.api.ratelimiting import Ratelimiter -from synapse.api.room_versions import EventFormatVersions -from synapse.crypto.event_signing import compute_event_reference_hash from synapse.events import EventBase -from synapse.events.builder import create_local_event_from_event_dict from synapse.events.snapshot import EventContext -from synapse.events.validator import EventValidator -from synapse.storage.roommember import RoomsForUser from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_left_room @@ -64,9 +57,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.state_handler = hs.get_state_handler() self.config = hs.config - self.federation_handler = hs.get_handlers().federation_handler - self.directory_handler = hs.get_handlers().directory_handler - self.identity_handler = hs.get_handlers().identity_handler + self.federation_handler = hs.get_federation_handler() + self.directory_handler = hs.get_directory_handler() + self.identity_handler = hs.get_identity_handler() self.registration_handler = hs.get_registration_handler() self.profile_handler = hs.get_profile_handler() self.event_creation_handler = hs.get_event_creation_handler() @@ -171,6 +164,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if requester.is_guest: content["kind"] = "guest" + # Check if we already have an event with a matching transaction ID. (We + # do this check just before we persist an event as well, but may as well + # do it up front for efficiency.) + if txn_id and requester.access_token_id: + existing_event_id = await self.store.get_event_id_from_transaction_id( + room_id, requester.user.to_string(), requester.access_token_id, txn_id, + ) + if existing_event_id: + event_pos = await self.store.get_position_for_event(existing_event_id) + return existing_event_id, event_pos.stream + event, context = await self.event_creation_handler.create_event( requester, { @@ -182,21 +186,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # For backwards compatibility: "membership": membership, }, - token_id=requester.access_token_id, txn_id=txn_id, prev_event_ids=prev_event_ids, require_consent=require_consent, ) - # Check if this event matches the previous membership event for the user. - duplicate = await self.event_creation_handler.deduplicate_state_event( - event, context - ) - if duplicate is not None: - # Discard the new event since this membership change is a no-op. - _, stream_id = await self.store.get_event_ordering(duplicate.event_id) - return duplicate.event_id, stream_id - prev_state_ids = await context.get_prev_state_ids() prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) @@ -209,7 +203,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # Only rate-limit if the user actually joined the room, otherwise we'll end # up blocking profile updates. 
- if newly_joined: + if newly_joined and ratelimit: time_now_s = self.clock.time() ( allowed, @@ -221,7 +215,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): retry_after_ms=int(1000 * (time_allowed - time_now_s)) ) - stream_id = await self.event_creation_handler.handle_new_client_event( + result_event = await self.event_creation_handler.handle_new_client_event( requester, event, context, extra_users=[target], ratelimit=ratelimit, ) @@ -231,7 +225,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if prev_member_event.membership == Membership.JOIN: await self._user_left_room(target, room_id) - return event.event_id, stream_id + # we know it was persisted, so should have a stream ordering + assert result_event.internal_metadata.stream_ordering + return result_event.event_id, result_event.internal_metadata.stream_ordering async def copy_room_tags_and_direct_to_room( self, old_room_id, new_room_id, user_id @@ -247,7 +243,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): user_account_data, _ = await self.store.get_account_data_for_user(user_id) # Copy direct message state if applicable - direct_rooms = user_account_data.get("m.direct", {}) + direct_rooms = user_account_data.get(AccountDataTypes.DIRECT, {}) # Check which key this room is under if isinstance(direct_rooms, dict): @@ -258,7 +254,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # Save back to user's m.direct account data await self.store.add_account_data_for_user( - user_id, "m.direct", direct_rooms + user_id, AccountDataTypes.DIRECT, direct_rooms ) break @@ -310,7 +306,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): key = (room_id,) with (await self.member_linearizer.queue(key)): - result = await self._update_membership( + result = await self.update_membership_locked( requester, target, room_id, @@ -325,7 +321,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): return result - async def _update_membership( + async def update_membership_locked( self, requester: Requester, target: UserID, @@ -338,6 +334,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): content: Optional[dict] = None, require_consent: bool = True, ) -> Tuple[str, int]: + """Helper for update_membership. + + Assumes that the membership linearizer is already held for the room. + """ content_specified = bool(content) if content is None: content = {} @@ -346,7 +346,15 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # later on. content = dict(content) - if not self.allow_per_room_profiles or requester.shadow_banned: + # allow the server notices mxid to set room-level profile + is_requester_server_notices_user = ( + self._server_notices_mxid is not None + and requester.user.to_string() == self._server_notices_mxid + ) + + if ( + not self.allow_per_room_profiles and not is_requester_server_notices_user + ) or requester.shadow_banned: # Strip profile data, knowing that new profile data will be added to the # event's content in event_creation_handler.create_event() using the target's # global profile. 
@@ -400,7 +408,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ) block_invite = True - if not self.spam_checker.user_may_invite( + if not await self.spam_checker.user_may_invite( requester.user.to_string(), target.to_string(), room_id ): logger.info("Blocking invite due to spam checker") @@ -441,12 +449,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): same_membership = old_membership == effective_membership_state same_sender = requester.user.to_string() == old_state.sender if same_sender and same_membership and same_content: - _, stream_id = await self.store.get_event_ordering( - old_state.event_id - ) + # duplicate event. + # we know it was persisted, so must have a stream ordering. + assert old_state.internal_metadata.stream_ordering return ( old_state.event_id, - stream_id, + old_state.internal_metadata.stream_ordering, ) if old_membership in ["ban", "leave"] and action == "kick": @@ -480,17 +488,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): raise AuthError(403, "Guest access not allowed") if not is_host_in_room: - time_now_s = self.clock.time() - ( - allowed, - time_allowed, - ) = self._join_rate_limiter_remote.can_requester_do_action(requester,) - - if not allowed: - raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now_s)) + if ratelimit: + time_now_s = self.clock.time() + ( + allowed, + time_allowed, + ) = self._join_rate_limiter_remote.can_requester_do_action( + requester, ) + if not allowed: + raise LimitExceededError( + retry_after_ms=int(1000 * (time_allowed - time_now_s)) + ) + inviter = await self._get_inviter(target.to_string(), room_id) if inviter and not self.hs.is_mine(inviter): remote_room_hosts.append(inviter.domain) @@ -514,10 +525,16 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): elif effective_membership_state == Membership.LEAVE: if not is_host_in_room: # perhaps we've been invited - invite = await self.store.get_invite_for_local_user_in_room( - user_id=target.to_string(), room_id=room_id - ) # type: Optional[RoomsForUser] - if not invite: + ( + current_membership_type, + current_membership_event_id, + ) = await self.store.get_local_current_membership_for_user_in_room( + target.to_string(), room_id + ) + if ( + current_membership_type != Membership.INVITE + or not current_membership_event_id + ): logger.info( "%s sent a leave request to %s, but that is not an active room " "on this server, and there is no pending invite", @@ -527,6 +544,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): raise SynapseError(404, "Not a known room") + invite = await self.store.get_event(current_membership_event_id) logger.info( "%s rejects invite to %s from %s", target, room_id, invite.sender ) @@ -642,7 +660,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): async def send_membership_event( self, - requester: Requester, + requester: Optional[Requester], event: EventBase, context: EventContext, ratelimit: bool = True, @@ -672,12 +690,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): else: requester = types.create_requester(target_user) - prev_event = await self.event_creation_handler.deduplicate_state_event( - event, context - ) - if prev_event is not None: - return - prev_state_ids = await context.get_prev_state_ids() if event.membership == Membership.JOIN: if requester.is_guest: @@ -692,7 +704,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if is_blocked: raise SynapseError(403, "This room has been blocked on this server") - await self.event_creation_handler.handle_new_client_event( + event = await 
self.event_creation_handler.handle_new_client_event( requester, event, context, extra_users=[target_user], ratelimit=ratelimit ) @@ -970,6 +982,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): self.distributor = hs.get_distributor() self.distributor.declare("user_left_room") + self._server_name = hs.hostname async def _is_remote_room_too_complex( self, room_id: str, remote_room_hosts: List[str] @@ -1064,7 +1077,9 @@ class RoomMemberMasterHandler(RoomMemberHandler): return event_id, stream_id # The room is too large. Leave. - requester = types.create_requester(user, None, False, False, None) + requester = types.create_requester( + user, authenticated_entity=self._server_name + ) await self.update_membership( requester=requester, target=user, room_id=room_id, action="leave" ) @@ -1109,57 +1124,38 @@ class RoomMemberMasterHandler(RoomMemberHandler): # logger.warning("Failed to reject invite: %s", e) - return await self._locally_reject_invite( + return await self._generate_local_out_of_band_leave( invite_event, txn_id, requester, content ) - async def _locally_reject_invite( + async def _generate_local_out_of_band_leave( self, - invite_event: EventBase, + previous_membership_event: EventBase, txn_id: Optional[str], requester: Requester, content: JsonDict, ) -> Tuple[str, int]: - """Generate a local invite rejection + """Generate a local leave event for a room - This is called after we fail to reject an invite via a remote server. It - generates an out-of-band membership event locally. + This can be called after we e.g fail to reject an invite via a remote server. + It generates an out-of-band membership event locally. Args: - invite_event: the invite to be rejected + previous_membership_event: the previous membership event for this user txn_id: optional transaction ID supplied by the client - requester: user making the rejection request, according to the access token - content: additional content to include in the rejection event. + requester: user making the request, according to the access token + content: additional content to include in the leave event. Normally an empty dict. - """ - room_id = invite_event.room_id - target_user = invite_event.state_key - room_version = await self.store.get_room_version(room_id) + Returns: + A tuple containing (event_id, stream_id of the leave event) + """ + room_id = previous_membership_event.room_id + target_user = previous_membership_event.state_key content["membership"] = Membership.LEAVE - # the auth events for the new event are the same as that of the invite, plus - # the invite itself. - # - # the prev_events are just the invite. 
- invite_hash = invite_event.event_id # type: Union[str, Tuple] - if room_version.event_format == EventFormatVersions.V1: - alg, h = compute_event_reference_hash(invite_event) - invite_hash = (invite_event.event_id, {alg: encode_base64(h)}) - - auth_events = tuple(invite_event.auth_events) + (invite_hash,) - prev_events = (invite_hash,) - - # we cap depth of generated events, to ensure that they are not - # rejected by other servers (and so that they can be persisted in - # the db) - depth = min(invite_event.depth + 1, MAX_DEPTH) - event_dict = { - "depth": depth, - "auth_events": auth_events, - "prev_events": prev_events, "type": EventTypes.Member, "room_id": room_id, "sender": target_user, @@ -1167,28 +1163,30 @@ class RoomMemberMasterHandler(RoomMemberHandler): "state_key": target_user, } - event = create_local_event_from_event_dict( - clock=self.clock, - hostname=self.hs.hostname, - signing_key=self.hs.signing_key, - room_version=room_version, - event_dict=event_dict, + # the auth events for the new event are the same as that of the previous event, plus + # the event itself. + # + # the prev_events consist solely of the previous membership event. + prev_event_ids = [previous_membership_event.event_id] + auth_event_ids = previous_membership_event.auth_event_ids() + prev_event_ids + + event, context = await self.event_creation_handler.create_event( + requester, + event_dict, + txn_id=txn_id, + prev_event_ids=prev_event_ids, + auth_event_ids=auth_event_ids, ) event.internal_metadata.outlier = True event.internal_metadata.out_of_band_membership = True - if txn_id is not None: - event.internal_metadata.txn_id = txn_id - if requester.access_token_id is not None: - event.internal_metadata.token_id = requester.access_token_id - EventValidator().validate_new(event, self.config) - - context = await self.state_handler.compute_event_context(event) - context.app_service = requester.app_service - stream_id = await self.event_creation_handler.handle_new_client_event( + result_event = await self.event_creation_handler.handle_new_client_event( requester, event, context, extra_users=[UserID.from_string(target_user)], ) - return event.event_id, stream_id + # we know it was persisted, so must have a stream ordering + assert result_event.internal_metadata.stream_ordering + + return result_event.event_id, result_event.internal_metadata.stream_ordering async def _user_left_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_left_room diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 285c481a96..5fa7ab3f8b 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
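saml_handler.py below moves error rendering and user mapping into the shared SSO handler; the SAML-specific mapping provider gets wrapped in a callback that the generic code can retry with a failure count. A simplified, synchronous sketch of that adapter shape (the real callback is async): `UserAttributes` here is a stand-in for `synapse.handlers.sso.UserAttributes`, and the dict keys assume a provider result of the mxid_localpart/displayname/emails form:

    from typing import Callable, List, Optional

    class UserAttributes:
        # Simplified stand-in for synapse.handlers.sso.UserAttributes.
        def __init__(self, localpart: Optional[str],
                     display_name: Optional[str] = None,
                     emails: Optional[List[str]] = None):
            self.localpart = localpart
            self.display_name = display_name
            self.emails = emails or []

    def make_attribute_callback(provider, saml2_auth, client_redirect_url
                                ) -> Callable[[int], UserAttributes]:
        # The SSO handler calls this back with the number of mapping failures
        # so far, letting the provider disambiguate colliding localparts.
        def callback(failures: int) -> UserAttributes:
            result = provider.saml_response_to_user_attributes(
                saml2_auth, failures, client_redirect_url
            )
            return UserAttributes(
                localpart=result.get("mxid_localpart"),
                display_name=result.get("displayname"),
                emails=result.get("emails", []),
            )
        return callback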
@@ -24,7 +24,8 @@ from saml2.client import Saml2Client from synapse.api.errors import SynapseError from synapse.config import ConfigError from synapse.config.saml2_config import SamlAttributeRequirement -from synapse.http.server import respond_with_html +from synapse.handlers._base import BaseHandler +from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest from synapse.module_api import ModuleApi @@ -33,19 +34,14 @@ from synapse.types import ( map_username_to_mxid_localpart, mxid_localpart_allowed_characters, ) -from synapse.util.async_helpers import Linearizer from synapse.util.iterutils import chunk_seq if TYPE_CHECKING: - import synapse.server + from synapse.server import HomeServer logger = logging.getLogger(__name__) -class MappingException(Exception): - """Used to catch errors when mapping the SAML2 response to a user.""" - - @attr.s(slots=True) class Saml2SessionData: """Data we track about SAML2 sessions""" @@ -57,17 +53,12 @@ class Saml2SessionData: ui_auth_session_id = attr.ib(type=Optional[str], default=None) -class SamlHandler: - def __init__(self, hs: "synapse.server.HomeServer"): - self.hs = hs +class SamlHandler(BaseHandler): + def __init__(self, hs: "HomeServer"): + super().__init__(hs) self._saml_client = Saml2Client(hs.config.saml2_sp_config) - self._auth = hs.get_auth() - self._auth_handler = hs.get_auth_handler() - self._registration_handler = hs.get_registration_handler() + self._saml_idp_entityid = hs.config.saml2_idp_entityid - self._clock = hs.get_clock() - self._datastore = hs.get_datastore() - self._hostname = hs.hostname self._saml2_session_lifetime = hs.config.saml2_session_lifetime self._grandfathered_mxid_source_attribute = ( hs.config.saml2_grandfathered_mxid_source_attribute @@ -87,27 +78,7 @@ class SamlHandler: # a map from saml session id to Saml2SessionData object self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] - # a lock on the mappings - self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock) - - def _render_error( - self, request, error: str, error_description: Optional[str] = None - ) -> None: - """Render the error template and respond to the request with it. - - This is used to show errors to the user. The template of this page can - be found under `synapse/res/templates/sso_error.html`. - - Args: - request: The incoming request from the browser. - We'll respond with an HTML page describing the error. - error: A technical identifier for this error. - error_description: A human-readable description of the error. - """ - html = self._error_template.render( - error=error, error_description=error_description - ) - respond_with_html(request, 400, html) + self._sso_handler = hs.get_sso_handler() def handle_redirect_request( self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None @@ -124,13 +95,13 @@ class SamlHandler: URL to redirect to """ reqid, info = self._saml_client.prepare_for_authenticate( - relay_state=client_redirect_url + entityid=self._saml_idp_entityid, relay_state=client_redirect_url ) # Since SAML sessions timeout it is useful to log when they were created. 
logger.info("Initiating a new SAML session: %s" % (reqid,)) - now = self._clock.time_msec() + now = self.clock.time_msec() self._outstanding_requests_dict[reqid] = Saml2SessionData( creation_time=now, ui_auth_session_id=ui_auth_session_id, ) @@ -171,12 +142,12 @@ class SamlHandler: # in the (user-visible) exception message, so let's log the exception here # so we can track down the session IDs later. logger.warning(str(e)) - self._render_error( + self._sso_handler.render_error( request, "unsolicited_response", "Unexpected SAML2 login." ) return except Exception as e: - self._render_error( + self._sso_handler.render_error( request, "invalid_response", "Unable to parse SAML2 response: %s." % (e,), @@ -184,12 +155,35 @@ class SamlHandler: return if saml2_auth.not_signed: - self._render_error( + self._sso_handler.render_error( request, "unsigned_respond", "SAML2 response was not signed." ) return logger.debug("SAML2 response: %s", saml2_auth.origxml) + + await self._handle_authn_response(request, saml2_auth, relay_state) + + async def _handle_authn_response( + self, + request: SynapseRequest, + saml2_auth: saml2.response.AuthnResponse, + relay_state: str, + ) -> None: + """Handle an AuthnResponse, having parsed it from the request params + + Assumes that the signature on the response object has been checked. Maps + the user onto an MXID, registering them if necessary, and returns a response + to the browser. + + Args: + request: the incoming request from the browser. We'll respond to it with an + HTML page or a redirect + saml2_auth: the parsed AuthnResponse object + relay_state: the RelayState query param, which encodes the URI to rediret + back to + """ + for assertion in saml2_auth.assertions: # kibana limits the length of a log field, whereas this is all rather # useful, so split it up. @@ -206,88 +200,86 @@ class SamlHandler: saml2_auth.in_response_to, None ) + # first check if we're doing a UIA + if current_session and current_session.ui_auth_session_id: + try: + remote_user_id = self._remote_id_from_saml_response(saml2_auth, None) + except MappingException as e: + logger.exception("Failed to extract remote user id from SAML response") + self._sso_handler.render_error(request, "mapping_error", str(e)) + return + + return await self._sso_handler.complete_sso_ui_auth_request( + self._auth_provider_id, + remote_user_id, + current_session.ui_auth_session_id, + request, + ) + + # otherwise, we're handling a login request. + # Ensure that the attributes of the logged in user meet the required # attributes. for requirement in self._saml2_attribute_requirements: if not _check_attribute_requirement(saml2_auth.ava, requirement): - self._render_error( + self._sso_handler.render_error( request, "unauthorised", "You are not authorised to log in here." ) return - # Pull out the user-agent and IP from the request. - user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ - 0 - ].decode("ascii", "surrogateescape") - ip_address = self.hs.get_ip_from_request(request) - # Call the mapper to register/login the user try: - user_id = await self._map_saml_response_to_user( - saml2_auth, relay_state, user_agent, ip_address - ) + await self._complete_saml_login(saml2_auth, request, relay_state) except MappingException as e: logger.exception("Could not map user") - self._render_error(request, "mapping_error", str(e)) - return + self._sso_handler.render_error(request, "mapping_error", str(e)) - # Complete the interactive auth session or the login. 
- if current_session and current_session.ui_auth_session_id: - await self._auth_handler.complete_sso_ui_auth( - user_id, current_session.ui_auth_session_id, request - ) - - else: - await self._auth_handler.complete_sso_login(user_id, request, relay_state) - - async def _map_saml_response_to_user( + async def _complete_saml_login( self, saml2_auth: saml2.response.AuthnResponse, + request: SynapseRequest, client_redirect_url: str, - user_agent: str, - ip_address: str, - ) -> str: + ) -> None: """ - Given a SAML response, retrieve the user ID for it and possibly register the user. + Given a SAML response, complete the login flow + + Retrieves the remote user ID, registers the user if necessary, and serves + a redirect back to the client with a login-token. Args: saml2_auth: The parsed SAML2 response. + request: The request to respond to client_redirect_url: The redirect URL passed in by the client. - user_agent: The user agent of the client making the request. - ip_address: The IP address of the client making the request. - - Returns: - The user ID associated with this response. Raises: MappingException if there was a problem mapping the response to a user. RedirectException: some mapping providers may raise this if they need to redirect to an interstitial page. """ - - remote_user_id = self._user_mapping_provider.get_remote_user_id( + remote_user_id = self._remote_id_from_saml_response( saml2_auth, client_redirect_url ) - if not remote_user_id: - raise MappingException( - "Failed to extract remote user id from SAML response" + async def saml_response_to_remapped_user_attributes( + failures: int, + ) -> UserAttributes: + """ + Call the mapping provider to map a SAML response to user attributes and coerce the result into the standard form. + + This is backwards compatibility for abstraction for the SSO handler. + """ + # Call the mapping provider. + result = self._user_mapping_provider.saml_response_to_user_attributes( + saml2_auth, failures, client_redirect_url ) - - with (await self._mapping_lock.queue(self._auth_provider_id)): - # first of all, check if we already have a mapping for this user - logger.info( - "Looking for existing mapping for user %s:%s", - self._auth_provider_id, - remote_user_id, - ) - registered_user_id = await self._datastore.get_user_by_external_id( - self._auth_provider_id, remote_user_id + # Remap some of the results. 
+ return UserAttributes( + localpart=result.get("mxid_localpart"), + display_name=result.get("displayname"), + emails=result.get("emails", []), ) - if registered_user_id is not None: - logger.info("Found existing mapping %s", registered_user_id) - return registered_user_id + async def grandfather_existing_users() -> Optional[str]: # backwards-compatibility hack: see if there is an existing user with a # suitable mapping from the uid if ( @@ -296,75 +288,63 @@ class SamlHandler: ): attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0] user_id = UserID( - map_username_to_mxid_localpart(attrval), self._hostname + map_username_to_mxid_localpart(attrval), self.server_name ).to_string() - logger.info( + + logger.debug( "Looking for existing account based on mapped %s %s", self._grandfathered_mxid_source_attribute, user_id, ) - users = await self._datastore.get_users_by_id_case_insensitive(user_id) + users = await self.store.get_users_by_id_case_insensitive(user_id) if users: registered_user_id = list(users.keys())[0] logger.info("Grandfathering mapping to %s", registered_user_id) - await self._datastore.record_user_external_id( - self._auth_provider_id, remote_user_id, registered_user_id - ) return registered_user_id - # Map saml response to user attributes using the configured mapping provider - for i in range(1000): - attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes( - saml2_auth, i, client_redirect_url=client_redirect_url, - ) + return None - logger.debug( - "Retrieved SAML attributes from user mapping provider: %s " - "(attempt %d)", - attribute_dict, - i, - ) + await self._sso_handler.complete_sso_login_request( + self._auth_provider_id, + remote_user_id, + request, + client_redirect_url, + saml_response_to_remapped_user_attributes, + grandfather_existing_users, + ) - localpart = attribute_dict.get("mxid_localpart") - if not localpart: - raise MappingException( - "Error parsing SAML2 response: SAML mapping provider plugin " - "did not return a mxid_localpart value" - ) - - displayname = attribute_dict.get("displayname") - emails = attribute_dict.get("emails", []) - - # Check if this mxid already exists - if not await self._datastore.get_users_by_id_case_insensitive( - UserID(localpart, self._hostname).to_string() - ): - # This mxid is free - break - else: - # Unable to generate a username in 1000 iterations - # Break and return error to the user - raise MappingException( - "Unable to generate a Matrix ID from the SAML response" - ) + def _remote_id_from_saml_response( + self, + saml2_auth: saml2.response.AuthnResponse, + client_redirect_url: Optional[str], + ) -> str: + """Extract the unique remote id from a SAML2 AuthnResponse - logger.info("Mapped SAML user to local part %s", localpart) + Args: + saml2_auth: The parsed SAML2 response. + client_redirect_url: The redirect URL passed in by the client. 
+ Returns: + remote user id - registered_user_id = await self._registration_handler.register_user( - localpart=localpart, - default_display_name=displayname, - bind_emails=emails, - user_agent_ips=(user_agent, ip_address), - ) + Raises: + MappingException if there was an error extracting the user id + """ + # It's not obvious why we need to pass in the redirect URI to the mapping + # provider, but we do :/ + remote_user_id = self._user_mapping_provider.get_remote_user_id( + saml2_auth, client_redirect_url + ) - await self._datastore.record_user_external_id( - self._auth_provider_id, remote_user_id, registered_user_id + if not remote_user_id: + raise MappingException( + "Failed to extract remote user id from SAML response" ) - return registered_user_id + + return remote_user_id def expire_sessions(self): - expire_before = self._clock.time_msec() - self._saml2_session_lifetime + expire_before = self.clock.time_msec() - self._saml2_session_lifetime to_expire = set() for reqid, data in self._outstanding_requests_dict.items(): if data.creation_time < expire_before: @@ -476,11 +456,11 @@ class DefaultSamlMappingProvider: ) # Use the configured mapper for this mxid_source - base_mxid_localpart = self._mxid_mapper(mxid_source) + localpart = self._mxid_mapper(mxid_source) # Append suffix integer if last call to this function failed to produce - # a usable mxid - localpart = base_mxid_localpart + (str(failures) if failures else "") + # a usable mxid. + localpart += str(failures) if failures else "" # Retrieve the display name from the saml response # If displayname is None, the mxid_localpart will be used instead diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
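In the saml_handler.py changes above, SamlHandler no longer registers users itself; it passes two coroutines to SsoHandler.complete_sso_login_request. A minimal sketch of the expected callback shapes, assuming UserAttributes from the new synapse.handlers.sso module added later in this diff (the localpart values are illustrative):

from typing import Optional

from synapse.handlers.sso import UserAttributes

async def map_attributes(failures: int) -> UserAttributes:
    # Append the failure count so a retry produces a different localpart,
    # mirroring DefaultSamlMappingProvider above.
    localpart = "alice" + (str(failures) if failures else "")
    return UserAttributes(localpart=localpart, display_name="Alice", emails=[])

async def grandfather_existing_users() -> Optional[str]:
    # Return an existing mxid to link to the remote id, or None to fall through
    # to normal registration.
    return None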
index e9402e6e2e..66f1bbcfc4 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -139,7 +139,7 @@ class SearchHandler(BaseHandler): # Filter to apply to results filter_dict = room_cat.get("filter", {}) - # What to order results by (impacts whether pagination can be doen) + # What to order results by (impacts whether pagination can be done) order_by = room_cat.get("order_by", "rank") # Return the current state of the rooms? diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py new file mode 100644
index 0000000000..33cd6bc178
--- /dev/null
+++ b/synapse/handlers/sso.py
@@ -0,0 +1,571 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional + +import attr +from typing_extensions import NoReturn + +from twisted.web.http import Request + +from synapse.api.errors import RedirectException, SynapseError +from synapse.http.server import respond_with_html +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters +from synapse.util.async_helpers import Linearizer +from synapse.util.stringutils import random_string + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class MappingException(Exception): + """Used to catch errors when mapping an SSO response to user attributes. + + Note that the msg that is raised is shown to end-users. + """ + + +@attr.s +class UserAttributes: + # the localpart of the mxid that the mapper has assigned to the user. + # if `None`, the mapper has not picked a userid, and the user should be prompted to + # enter one. + localpart = attr.ib(type=Optional[str]) + display_name = attr.ib(type=Optional[str], default=None) + emails = attr.ib(type=List[str], default=attr.Factory(list)) + + +@attr.s(slots=True) +class UsernameMappingSession: + """Data we track about SSO sessions""" + + # A unique identifier for this SSO provider, e.g. "oidc" or "saml". + auth_provider_id = attr.ib(type=str) + + # user ID on the IdP server + remote_user_id = attr.ib(type=str) + + # attributes returned by the ID mapper + display_name = attr.ib(type=Optional[str]) + emails = attr.ib(type=List[str]) + + # An optional dictionary of extra attributes to be provided to the client in the + # login response. + extra_login_attributes = attr.ib(type=Optional[JsonDict]) + + # where to redirect the client back to + client_redirect_url = attr.ib(type=str) + + # expiry time for the session, in milliseconds + expiry_time_ms = attr.ib(type=int) + + +# the HTTP cookie used to track the mapping session id +USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session" + + +class SsoHandler: + # The number of attempts to ask the mapping provider for when generating an MXID. 
+ _MAP_USERNAME_RETRIES = 1000 + + # the time a UsernameMappingSession remains valid for + _MAPPING_SESSION_VALIDITY_PERIOD_MS = 15 * 60 * 1000 + + def __init__(self, hs: "HomeServer"): + self._clock = hs.get_clock() + self._store = hs.get_datastore() + self._server_name = hs.hostname + self._registration_handler = hs.get_registration_handler() + self._error_template = hs.config.sso_error_template + self._auth_handler = hs.get_auth_handler() + + # a lock on the mappings + self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock()) + + # a map from session id to session data + self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession] + + def render_error( + self, + request: Request, + error: str, + error_description: Optional[str] = None, + code: int = 400, + ) -> None: + """Renders the error template and responds with it. + + This is used to show errors to the user. The template of this page can + be found under `synapse/res/templates/sso_error.html`. + + Args: + request: The incoming request from the browser. + We'll respond with an HTML page describing the error. + error: A technical identifier for this error. + error_description: A human-readable description of the error. + code: The integer error code (an HTTP response code) + """ + html = self._error_template.render( + error=error, error_description=error_description + ) + respond_with_html(request, code, html) + + async def get_sso_user_by_remote_user_id( + self, auth_provider_id: str, remote_user_id: str + ) -> Optional[str]: + """ + Maps the user ID of a remote IdP to a mxid for a previously seen user. + + If the user has not been seen yet, this will return None. + + Args: + auth_provider_id: A unique identifier for this SSO provider, e.g. + "oidc" or "saml". + remote_user_id: The user ID according to the remote IdP. This might + be an e-mail address, a GUID, or some other form. It must be + unique and immutable. + + Returns: + The mxid of a previously seen user. + """ + logger.debug( + "Looking for existing mapping for user %s:%s", + auth_provider_id, + remote_user_id, + ) + + # Check if we already have a mapping for this user. + previously_registered_user_id = await self._store.get_user_by_external_id( + auth_provider_id, remote_user_id, + ) + + # A match was found, return the user ID. + if previously_registered_user_id is not None: + logger.info( + "Found existing mapping for IdP '%s' and remote_user_id '%s': %s", + auth_provider_id, + remote_user_id, + previously_registered_user_id, + ) + return previously_registered_user_id + + # No match. + return None + + async def complete_sso_login_request( + self, + auth_provider_id: str, + remote_user_id: str, + request: SynapseRequest, + client_redirect_url: str, + sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], + grandfather_existing_users: Callable[[], Awaitable[Optional[str]]], + extra_login_attributes: Optional[JsonDict] = None, + ) -> None: + """ + Given an SSO ID, retrieve the user ID for it and possibly register the user. + + This first checks if the SSO ID has previously been linked to a matrix ID, + if it has that matrix ID is returned regardless of the current mapping + logic. + + If a callable is provided for grandfathering users, it is called and can + potentially return a matrix ID to use. If it does, the SSO ID is linked to + this matrix ID for subsequent calls. + + The mapping function is called (potentially multiple times) to generate + a localpart for the user. 
+ + If an unused localpart is generated, the user is registered from the + given user-agent and IP address and the SSO ID is linked to this matrix + ID for subsequent calls. + + Finally, we generate a redirect to the supplied redirect uri, with a login token + + Args: + auth_provider_id: A unique identifier for this SSO provider, e.g. + "oidc" or "saml". + + remote_user_id: The unique identifier from the SSO provider. + + request: The request to respond to + + client_redirect_url: The redirect URL passed in by the client. + + sso_to_matrix_id_mapper: A callable to generate the user attributes. + The only parameter is an integer which represents the amount of + times the returned mxid localpart mapping has failed. + + It is expected that the mapper can raise two exceptions, which + will get passed through to the caller: + + MappingException if there was a problem mapping the response + to the user. + RedirectException to redirect to an additional page (e.g. + to prompt the user for more information). + + grandfather_existing_users: A callable which can return an previously + existing matrix ID. The SSO ID is then linked to the returned + matrix ID. + + extra_login_attributes: An optional dictionary of extra + attributes to be provided to the client in the login response. + + Raises: + MappingException if there was a problem mapping the response to a user. + RedirectException: if the mapping provider needs to redirect the user + to an additional page. (e.g. to prompt for more information) + + """ + # grab a lock while we try to find a mapping for this user. This seems... + # optimistic, especially for implementations that end up redirecting to + # interstitial pages. + with await self._mapping_lock.queue(auth_provider_id): + # first of all, check if we already have a mapping for this user + user_id = await self.get_sso_user_by_remote_user_id( + auth_provider_id, remote_user_id, + ) + + # Check for grandfathering of users. + if not user_id: + user_id = await grandfather_existing_users() + if user_id: + # Future logins should also match this user ID. + await self._store.record_user_external_id( + auth_provider_id, remote_user_id, user_id + ) + + # Otherwise, generate a new user. + if not user_id: + attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper) + + if attributes.localpart is None: + # the mapper doesn't return a username. bail out with a redirect to + # the username picker. + await self._redirect_to_username_picker( + auth_provider_id, + remote_user_id, + attributes, + client_redirect_url, + extra_login_attributes, + ) + + user_id = await self._register_mapped_user( + attributes, + auth_provider_id, + remote_user_id, + request.get_user_agent(""), + request.getClientIP(), + ) + + await self._auth_handler.complete_sso_login( + user_id, request, client_redirect_url, extra_login_attributes + ) + + async def _call_attribute_mapper( + self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], + ) -> UserAttributes: + """Call the attribute mapper function in a loop, until we get a unique userid""" + for i in range(self._MAP_USERNAME_RETRIES): + try: + attributes = await sso_to_matrix_id_mapper(i) + except (RedirectException, MappingException): + # Mapping providers are allowed to issue a redirect (e.g. to ask + # the user for more information) and can issue a mapping exception + # if a name cannot be generated. + raise + except Exception as e: + # Any other exception is unexpected. 
+ raise MappingException( + "Could not extract user attributes from SSO response." + ) from e + + logger.debug( + "Retrieved user attributes from user mapping provider: %r (attempt %d)", + attributes, + i, + ) + + if not attributes.localpart: + # the mapper has not picked a localpart + return attributes + + # Check if this mxid already exists + user_id = UserID(attributes.localpart, self._server_name).to_string() + if not await self._store.get_users_by_id_case_insensitive(user_id): + # This mxid is free + break + else: + # Unable to generate a username in 1000 iterations + # Break and return error to the user + raise MappingException( + "Unable to generate a Matrix ID from the SSO response" + ) + return attributes + + async def _redirect_to_username_picker( + self, + auth_provider_id: str, + remote_user_id: str, + attributes: UserAttributes, + client_redirect_url: str, + extra_login_attributes: Optional[JsonDict], + ) -> NoReturn: + """Creates a UsernameMappingSession and redirects the browser + + Called if the user mapping provider doesn't return a localpart for a new user. + Raises a RedirectException which redirects the browser to the username picker. + + Args: + auth_provider_id: A unique identifier for this SSO provider, e.g. + "oidc" or "saml". + + remote_user_id: The unique identifier from the SSO provider. + + attributes: the user attributes returned by the user mapping provider. + + client_redirect_url: The redirect URL passed in by the client, which we + will eventually redirect back to. + + extra_login_attributes: An optional dictionary of extra + attributes to be provided to the client in the login response. + + Raises: + RedirectException + """ + session_id = random_string(16) + now = self._clock.time_msec() + session = UsernameMappingSession( + auth_provider_id=auth_provider_id, + remote_user_id=remote_user_id, + display_name=attributes.display_name, + emails=attributes.emails, + client_redirect_url=client_redirect_url, + expiry_time_ms=now + self._MAPPING_SESSION_VALIDITY_PERIOD_MS, + extra_login_attributes=extra_login_attributes, + ) + + self._username_mapping_sessions[session_id] = session + logger.info("Recorded registration session id %s", session_id) + + # Set the cookie and redirect to the username picker + e = RedirectException(b"/_synapse/client/pick_username") + e.cookies.append( + b"%s=%s; path=/" + % (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii")) + ) + raise e + + async def _register_mapped_user( + self, + attributes: UserAttributes, + auth_provider_id: str, + remote_user_id: str, + user_agent: str, + ip_address: str, + ) -> str: + """Register a new SSO user. + + This is called once we have successfully mapped the remote user id onto a local + user id, one way or another. + + Args: + attributes: user attributes returned by the user mapping provider, + including a non-empty localpart. + + auth_provider_id: A unique identifier for this SSO provider, e.g. + "oidc" or "saml". + + remote_user_id: The unique identifier from the SSO provider. + + user_agent: The user-agent in the HTTP request (used for potential + shadow-banning.) + + ip_address: The IP address of the requester (used for potential + shadow-banning.) + + Raises: + a MappingException if the localpart is invalid. + + a SynapseError with code 400 and errcode Codes.USER_IN_USE if the localpart + is already taken. + """ + + # Since the localpart is provided via a potentially untrusted module, + # ensure the MXID is valid before registering. 
+ if not attributes.localpart or contains_invalid_mxid_characters( + attributes.localpart + ): + raise MappingException("localpart is invalid: %s" % (attributes.localpart,)) + + logger.debug("Mapped SSO user to local part %s", attributes.localpart) + registered_user_id = await self._registration_handler.register_user( + localpart=attributes.localpart, + default_display_name=attributes.display_name, + bind_emails=attributes.emails, + user_agent_ips=[(user_agent, ip_address)], + ) + + await self._store.record_user_external_id( + auth_provider_id, remote_user_id, registered_user_id + ) + return registered_user_id + + async def complete_sso_ui_auth_request( + self, + auth_provider_id: str, + remote_user_id: str, + ui_auth_session_id: str, + request: Request, + ) -> None: + """ + Given an SSO ID, retrieve the user ID for it and complete UIA. + + Note that this requires that the user is mapped in the "user_external_ids" + table. This will be the case if they have ever logged in via SAML or OIDC in + recentish synapse versions, but may not be for older users. + + Args: + auth_provider_id: A unique identifier for this SSO provider, e.g. + "oidc" or "saml". + remote_user_id: The unique identifier from the SSO provider. + ui_auth_session_id: The ID of the user-interactive auth session. + request: The request to complete. + """ + + user_id = await self.get_sso_user_by_remote_user_id( + auth_provider_id, remote_user_id, + ) + + if not user_id: + logger.warning( + "Remote user %s/%s has not previously logged in here: UIA will fail", + auth_provider_id, + remote_user_id, + ) + # Let the UIA flow handle this the same as if they presented creds for a + # different user. + user_id = "" + + await self._auth_handler.complete_sso_ui_auth( + user_id, ui_auth_session_id, request + ) + + async def check_username_availability( + self, localpart: str, session_id: str, + ) -> bool: + """Handle an "is username available" callback check + + Args: + localpart: desired localpart + session_id: the session id for the username picker + Returns: + True if the username is available + Raises: + SynapseError if the localpart is invalid or the session is unknown + """ + + # make sure that there is a valid mapping session, to stop people dictionary- + # scanning for accounts + + self._expire_old_sessions() + session = self._username_mapping_sessions.get(session_id) + if not session: + logger.info("Couldn't find session id %s", session_id) + raise SynapseError(400, "unknown session") + + logger.info( + "[session %s] Checking for availability of username %s", + session_id, + localpart, + ) + + if contains_invalid_mxid_characters(localpart): + raise SynapseError(400, "localpart is invalid: %s" % (localpart,)) + user_id = UserID(localpart, self._server_name).to_string() + user_infos = await self._store.get_users_by_id_case_insensitive(user_id) + + logger.info("[session %s] users: %s", session_id, user_infos) + return not user_infos + + async def handle_submit_username_request( + self, request: SynapseRequest, localpart: str, session_id: str + ) -> None: + """Handle a request to the username-picker 'submit' endpoint + + Will serve an HTTP response to the request. 
+ + Args: + request: HTTP request + localpart: localpart requested by the user + session_id: ID of the username mapping session, extracted from a cookie + """ + self._expire_old_sessions() + session = self._username_mapping_sessions.get(session_id) + if not session: + logger.info("Couldn't find session id %s", session_id) + raise SynapseError(400, "unknown session") + + logger.info("[session %s] Registering localpart %s", session_id, localpart) + + attributes = UserAttributes( + localpart=localpart, + display_name=session.display_name, + emails=session.emails, + ) + + # the following will raise a 400 error if the username has been taken in the + # meantime. + user_id = await self._register_mapped_user( + attributes, + session.auth_provider_id, + session.remote_user_id, + request.get_user_agent(""), + request.getClientIP(), + ) + + logger.info("[session %s] Registered userid %s", session_id, user_id) + + # delete the mapping session and the cookie + del self._username_mapping_sessions[session_id] + + # delete the cookie + request.addCookie( + USERNAME_MAPPING_SESSION_COOKIE_NAME, + b"", + expires=b"Thu, 01 Jan 1970 00:00:00 GMT", + path=b"/", + ) + + await self._auth_handler.complete_sso_login( + user_id, + request, + session.client_redirect_url, + session.extra_login_attributes, + ) + + def _expire_old_sessions(self): + to_expire = [] + now = int(self._clock.time_msec()) + + for session_id, session in self._username_mapping_sessions.items(): + if session.expiry_time_ms <= now: + to_expire.append(session_id) + + for session_id in to_expire: + logger.info("Expiring mapping session %s", session_id) + del self._username_mapping_sessions[session_id] diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py
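The new SsoHandler above retries the attribute mapper until it yields a free localpart, bails out to the username picker if no localpart is returned, and gives up after _MAP_USERNAME_RETRIES attempts. A standalone, synchronous illustration of that retry pattern (names here are illustrative, not Synapse APIs):

from typing import Callable, Optional

def pick_free_localpart(
    mapper: Callable[[int], Optional[str]],
    is_taken: Callable[[str], bool],
    max_retries: int = 1000,
) -> Optional[str]:
    for i in range(max_retries):
        localpart = mapper(i)   # e.g. "alice", "alice1", "alice2", ...
        if localpart is None:
            return None         # defer to the username picker
        if not is_taken(localpart):
            return localpart    # this localpart is free
    raise RuntimeError("Unable to generate a Matrix ID from the SSO response")

# Example: "alice" is taken, so the first retry appends the failure count.
taken = {"alice"}

def mapper(failures: int) -> str:
    return "alice" + (str(failures) if failures else "")

assert pick_free_localpart(mapper, taken.__contains__) == "alice1"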
index 7a4ae0727a..fb4f70e8e2 100644
--- a/synapse/handlers/state_deltas.py
+++ b/synapse/handlers/state_deltas.py
@@ -32,7 +32,7 @@ class StateDeltasHandler: Returns: None if the field in the events either both match `public_value` or if neither do, i.e. there has been no change. - True if it didnt match `public_value` but now does + True if it didn't match `public_value` but now does False if it did match `public_value` but now doesn't """ prev_event = None diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
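The corrected docstring in the state_deltas.py hunk above describes a tri-state result. A standalone sketch of that logic, simplified from _get_key_change (which additionally has to fetch and null-check the two events):

from typing import Optional

def key_change(prev_value: str, new_value: str, public_value: str) -> Optional[bool]:
    # None: no change in publicness; True: became public; False: became private.
    prev_public = prev_value == public_value
    new_public = new_value == public_value
    if prev_public == new_public:
        return None
    return new_public

assert key_change("shared", "world_readable", "world_readable") is True
assert key_change("world_readable", "shared", "world_readable") is False
assert key_change("shared", "invited", "world_readable") is None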
index 249ffe2a55..dc62b21c06 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -49,7 +49,7 @@ class StatsHandler: # Guard to ensure we only process deltas one at a time self._is_processing = False - if hs.config.stats_enabled: + if self.stats_enabled and hs.config.run_background_tasks: self.notifier.add_replication_callback(self.notify_new_event) # We kick this off so that we don't have to wait for a change before diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index bfe2583002..5c7590f38e 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import itertools import logging from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple @@ -21,7 +20,7 @@ from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tup import attr from prometheus_client import Counter -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import AccountDataTypes, EventTypes, Membership from synapse.api.filtering import FilterCollection from synapse.events import EventBase from synapse.logging.context import current_context @@ -32,6 +31,7 @@ from synapse.types import ( Collection, JsonDict, MutableStateMap, + Requester, RoomStreamToken, StateMap, StreamToken, @@ -87,7 +87,7 @@ class SyncConfig: class TimelineBatch: prev_batch = attr.ib(type=StreamToken) events = attr.ib(type=List[EventBase]) - limited = attr.ib(bool) + limited = attr.ib(type=bool) def __bool__(self) -> bool: """Make the result appear empty if there are no updates. This is used @@ -201,6 +201,8 @@ class SyncResult: device_lists: List of user_ids whose devices have changed device_one_time_keys_count: Dict of algorithm to count for one time keys for this device + device_unused_fallback_key_types: List of key types that have an unused fallback + key groups: Group updates, if any """ @@ -213,6 +215,7 @@ class SyncResult: to_device = attr.ib(type=List[JsonDict]) device_lists = attr.ib(type=DeviceLists) device_one_time_keys_count = attr.ib(type=JsonDict) + device_unused_fallback_key_types = attr.ib(type=List[str]) groups = attr.ib(type=Optional[GroupsSyncResult]) def __bool__(self) -> bool: @@ -240,7 +243,9 @@ class SyncHandler: self.presence_handler = hs.get_presence_handler() self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() - self.response_cache = ResponseCache(hs, "sync") + self.response_cache = ResponseCache( + hs, "sync" + ) # type: ResponseCache[Tuple[Any, ...]] self.state = hs.get_state_handler() self.auth = hs.get_auth() self.storage = hs.get_storage() @@ -256,6 +261,7 @@ class SyncHandler: async def wait_for_sync_for_user( self, + requester: Requester, sync_config: SyncConfig, since_token: Optional[StreamToken] = None, timeout: int = 0, @@ -269,7 +275,7 @@ class SyncHandler: # not been exceeded (if not part of the group by this point, almost certain # auth_blocking will occur) user_id = sync_config.user.to_string() - await self.auth.check_auth_blocking(user_id) + await self.auth.check_auth_blocking(requester=requester) res = await self.response_cache.wrap( sync_config.request_key, @@ -457,8 +463,13 @@ class SyncHandler: recents = [] if not limited or block_all_timeline: + prev_batch_token = now_token + if recents: + room_key = recents[0].internal_metadata.before + prev_batch_token = now_token.copy_and_replace("room_key", room_key) + return TimelineBatch( - events=recents, prev_batch=now_token, limited=False + events=recents, prev_batch=prev_batch_token, limited=False ) filtering_factor = 2 @@ -543,7 +554,7 @@ class SyncHandler: event.event_id, state_filter=state_filter ) if event.is_state(): - state_ids = state_ids.copy() + state_ids = dict(state_ids) state_ids[(event.type, event.state_key)] = event.event_id return state_ids @@ -745,7 +756,7 @@ class SyncHandler: """ # TODO(mjark) Check if the state events were received by the server # after the previous sync, since we need to include those state 
- # updates even if they occured logically before the previous event. + # updates even if they occurred logically before the previous event. # TODO(mjark) Check for new redactions in the state events. with Measure(self.clock, "compute_state_delta"): @@ -1014,10 +1025,14 @@ class SyncHandler: logger.debug("Fetching OTK data") device_id = sync_config.device_id one_time_key_counts = {} # type: JsonDict + unused_fallback_key_types = [] # type: List[str] if device_id: one_time_key_counts = await self.store.count_e2e_one_time_keys( user_id, device_id ) + unused_fallback_key_types = await self.store.get_e2e_unused_fallback_key_types( + user_id, device_id + ) logger.debug("Fetching group data") await self._generate_sync_entry_for_groups(sync_result_builder) @@ -1041,6 +1056,7 @@ class SyncHandler: device_lists=device_lists, groups=sync_result_builder.groups, device_one_time_keys_count=one_time_key_counts, + device_unused_fallback_key_types=unused_fallback_key_types, next_batch=sync_result_builder.now_token, ) @@ -1378,13 +1394,16 @@ class SyncHandler: return set(), set(), set(), set() ignored_account_data = await self.store.get_global_account_data_by_type_for_user( - "m.ignored_user_list", user_id=user_id + AccountDataTypes.IGNORED_USER_LIST, user_id=user_id ) + # If there is ignored users account data and it matches the proper type, + # then use it. + ignored_users = frozenset() # type: FrozenSet[str] if ignored_account_data: - ignored_users = ignored_account_data.get("ignored_users", {}).keys() - else: - ignored_users = frozenset() + ignored_users_data = ignored_account_data.get("ignored_users", {}) + if isinstance(ignored_users_data, dict): + ignored_users = frozenset(ignored_users_data.keys()) if since_token: room_changes = await self._get_rooms_changed( @@ -1478,7 +1497,7 @@ class SyncHandler: return False async def _get_rooms_changed( - self, sync_result_builder: "SyncResultBuilder", ignored_users: Set[str] + self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str] ) -> _RoomChanges: """Gets the the changes that have happened since the last sync. """ @@ -1690,7 +1709,7 @@ class SyncHandler: return _RoomChanges(room_entries, invited, newly_joined_rooms, newly_left_rooms) async def _get_all_rooms( - self, sync_result_builder: "SyncResultBuilder", ignored_users: Set[str] + self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str] ) -> _RoomChanges: """Returns entries for all rooms for the user. @@ -1764,7 +1783,7 @@ class SyncHandler: async def _generate_room_entry( self, sync_result_builder: "SyncResultBuilder", - ignored_users: Set[str], + ignored_users: FrozenSet[str], room_builder: "RoomSyncResultBuilder", ephemeral: List[JsonDict], tags: Optional[Dict[str, Dict[str, Any]]], @@ -1865,7 +1884,7 @@ class SyncHandler: # members (as the client otherwise doesn't have enough info to form # the name itself). if sync_config.filter_collection.lazy_load_members() and ( - # we recalulate the summary: + # we recalculate the summary: # if there are membership changes in the timeline, or # if membership has changed during a gappy sync, or # if this is an initial sync. diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
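Among the sync.py changes above, the m.ignored_user_list parsing becomes defensive: the account data is only trusted when its "ignored_users" value is a dict. A standalone sketch of that check (the function name is illustrative):

from typing import Any, FrozenSet, Mapping

def parse_ignored_users(account_data: Mapping[str, Any]) -> FrozenSet[str]:
    ignored = account_data.get("ignored_users", {})
    if isinstance(ignored, dict):
        return frozenset(ignored.keys())
    # Malformed account data (e.g. a list) is treated as "nobody ignored".
    return frozenset()

assert parse_ignored_users({"ignored_users": {"@spam:example.org": {}}}) == frozenset({"@spam:example.org"})
assert parse_ignored_users({"ignored_users": ["@spam:example.org"]}) == frozenset()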
index 3cbfc2d780..e919a8f9ed 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -12,16 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging import random from collections import namedtuple from typing import TYPE_CHECKING, List, Set, Tuple from synapse.api.errors import AuthError, ShadowBanError, SynapseError +from synapse.appservice import ApplicationService from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.streams import TypingStream -from synapse.types import UserID, get_domain_from_id +from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer @@ -167,20 +167,25 @@ class FollowerTypingHandler: now_typing = set(row.user_ids) self._room_typing[row.room_id] = row.user_ids - run_as_background_process( - "_handle_change_in_typing", - self._handle_change_in_typing, - row.room_id, - prev_typing, - now_typing, - ) + if self.federation: + run_as_background_process( + "_send_changes_in_typing_to_remotes", + self._send_changes_in_typing_to_remotes, + row.room_id, + prev_typing, + now_typing, + ) - async def _handle_change_in_typing( + async def _send_changes_in_typing_to_remotes( self, room_id: str, prev_typing: Set[str], now_typing: Set[str] ): """Process a change in typing of a room from replication, sending EDUs for any local users. """ + + if not self.federation: + return + for user_id in now_typing - prev_typing: if self.is_mine_id(user_id): await self._push_remote(RoomMember(room_id, user_id), True) @@ -371,7 +376,7 @@ class TypingWriterHandler(FollowerTypingHandler): between the requested tokens due to the limit. The token returned can be used in a subsequent call to this - function to get further updatees. + function to get further updates. The updates are a list of 2-tuples of stream ID and the row data """ @@ -430,6 +435,33 @@ class TypingNotificationEventSource: "content": {"user_ids": list(typing)}, } + async def get_new_events_as( + self, from_key: int, service: ApplicationService + ) -> Tuple[List[JsonDict], int]: + """Returns a set of new typing events that an appservice + may be interested in. + + Args: + from_key: the stream position at which events should be fetched from + service: The appservice which may be interested + """ + with Measure(self.clock, "typing.get_new_events_as"): + from_key = int(from_key) + handler = self.get_typing_handler() + + events = [] + for room_id in handler._room_serials.keys(): + if handler._room_serials[room_id] <= from_key: + continue + if not await service.matches_user_in_member_list( + room_id, handler.store + ): + continue + + events.append(self._make_event_for(room_id)) + + return (events, handler._latest_room_serial) + async def get_new_events(self, from_key, room_ids, **kwargs): with Measure(self.clock, "typing.get_new_events"): from_key = int(from_key) diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
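The new get_new_events_as in typing.py above only considers rooms whose typing serial has advanced past the requested stream position (and whose members interest the appservice). A minimal sketch of the stream-position part of that filter (names are illustrative):

from typing import Dict, List

def rooms_with_new_typing(room_serials: Dict[str, int], from_key: int) -> List[str]:
    # Only rooms whose typing serial is newer than from_key have new events.
    return [room_id for room_id, serial in room_serials.items() if serial > from_key]

assert rooms_with_new_typing({"!a:hs": 3, "!b:hs": 7}, from_key=5) == ["!b:hs"]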
index 9146dc1a3b..3d66bf305e 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -143,7 +143,7 @@ class _BaseThreepidAuthChecker: threepid_creds = authdict["threepid_creds"] - identity_handler = self.hs.get_handlers().identity_handler + identity_handler = self.hs.get_identity_handler() logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,)) diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 79393c8829..d4651c8348 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -14,14 +14,19 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Any, Dict, List, Optional import synapse.metrics -from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership from synapse.handlers.state_deltas import StateDeltasHandler from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.roommember import ProfileInfo +from synapse.types import JsonDict from synapse.util.metrics import Measure +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -31,12 +36,12 @@ class UserDirectoryHandler(StateDeltasHandler): N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY The user directory is filled with users who this server can see are joined to a - world_readable or publically joinable room. We keep a database table up to date + world_readable or publicly joinable room. We keep a database table up to date by streaming changes of the current state and recalculating whether users should be in the directory or not when necessary. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.store = hs.get_datastore() @@ -49,7 +54,7 @@ class UserDirectoryHandler(StateDeltasHandler): self.search_all_users = hs.config.user_directory_search_all_users self.spam_checker = hs.get_spam_checker() # The current position in the current_state_delta stream - self.pos = None + self.pos = None # type: Optional[int] # Guard to ensure we only process deltas one at a time self._is_processing = False @@ -61,7 +66,9 @@ class UserDirectoryHandler(StateDeltasHandler): # we start populating the user directory self.clock.call_later(0, self.notify_new_event) - async def search_users(self, user_id, search_term, limit): + async def search_users( + self, user_id: str, search_term: str, limit: int + ) -> JsonDict: """Searches for users in directory Returns: @@ -81,15 +88,15 @@ class UserDirectoryHandler(StateDeltasHandler): results = await self.store.search_user_dir(user_id, search_term, limit) # Remove any spammy users from the results. - results["results"] = [ - user - for user in results["results"] - if not self.spam_checker.check_username_for_spam(user) - ] + non_spammy_users = [] + for user in results["results"]: + if not await self.spam_checker.check_username_for_spam(user): + non_spammy_users.append(user) + results["results"] = non_spammy_users return results - def notify_new_event(self): + def notify_new_event(self) -> None: """Called when there may be more deltas to process """ if not self.update_user_directory: @@ -107,27 +114,33 @@ class UserDirectoryHandler(StateDeltasHandler): self._is_processing = True run_as_background_process("user_directory.notify_new_event", process) - async def handle_local_profile_change(self, user_id, profile): + async def handle_local_profile_change( + self, user_id: str, profile: ProfileInfo + ) -> None: """Called to update index of our local user profiles when they change irrespective of any rooms the user may be in. """ # FIXME(#3714): We should probably do this in the same worker as all # the other changes. - is_support = await self.store.is_support_user(user_id) + # Support users are for diagnostics and should not appear in the user directory. 
- if not is_support: + is_support = await self.store.is_support_user(user_id) + # When change profile information of deactivated user it should not appear in the user directory. + is_deactivated = await self.store.get_user_deactivated_status(user_id) + + if not (is_support or is_deactivated): await self.store.update_profile_in_user_dir( user_id, profile.display_name, profile.avatar_url ) - async def handle_user_deactivated(self, user_id): + async def handle_user_deactivated(self, user_id: str) -> None: """Called when a user ID is deactivated """ # FIXME(#3714): We should probably do this in the same worker as all # the other changes. await self.store.remove_from_user_dir(user_id) - async def _unsafe_process(self): + async def _unsafe_process(self) -> None: # If self.pos is None then means we haven't fetched it from DB if self.pos is None: self.pos = await self.store.get_user_directory_stream_pos() @@ -162,7 +175,7 @@ class UserDirectoryHandler(StateDeltasHandler): await self.store.update_user_directory_stream_pos(max_pos) - async def _handle_deltas(self, deltas): + async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None: """Called with the state deltas to process """ for delta in deltas: @@ -232,16 +245,20 @@ class UserDirectoryHandler(StateDeltasHandler): logger.debug("Ignoring irrelevant type: %r", typ) async def _handle_room_publicity_change( - self, room_id, prev_event_id, event_id, typ - ): + self, + room_id: str, + prev_event_id: Optional[str], + event_id: Optional[str], + typ: str, + ) -> None: """Handle a room having potentially changed from/to world_readable/publicly joinable. Args: - room_id (str) - prev_event_id (str|None): The previous event before the state change - event_id (str|None): The new event after the state change - typ (str): Type of the event + room_id: The ID of the room which changed. 
+ prev_event_id: The previous event before the state change + event_id: The new event after the state change + typ: Type of the event """ logger.debug("Handling change for %s: %s", typ, room_id) @@ -250,7 +267,7 @@ class UserDirectoryHandler(StateDeltasHandler): prev_event_id, event_id, key_name="history_visibility", - public_value="world_readable", + public_value=HistoryVisibility.WORLD_READABLE, ) elif typ == EventTypes.JoinRules: change = await self._get_key_change( @@ -299,12 +316,14 @@ class UserDirectoryHandler(StateDeltasHandler): for user_id, profile in users_with_profile.items(): await self._handle_new_user(room_id, user_id, profile) - async def _handle_new_user(self, room_id, user_id, profile): + async def _handle_new_user( + self, room_id: str, user_id: str, profile: ProfileInfo + ) -> None: """Called when we might need to add user to directory Args: - room_id (str): room_id that user joined or started being public - user_id (str) + room_id: The room ID that user joined or started being public + user_id """ logger.debug("Adding new user to dir, %r", user_id) @@ -352,12 +371,12 @@ class UserDirectoryHandler(StateDeltasHandler): if to_insert: await self.store.add_users_who_share_private_room(room_id, to_insert) - async def _handle_remove_user(self, room_id, user_id): + async def _handle_remove_user(self, room_id: str, user_id: str) -> None: """Called when we might need to remove user from directory Args: - room_id (str): room_id that user left or stopped being public that - user_id (str) + room_id: The room ID that user left or stopped being public that + user_id """ logger.debug("Removing user %r", user_id) @@ -370,7 +389,13 @@ class UserDirectoryHandler(StateDeltasHandler): if len(rooms_user_is_in) == 0: await self.store.remove_from_user_dir(user_id) - async def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id): + async def _handle_profile_change( + self, + user_id: str, + room_id: str, + prev_event_id: Optional[str], + event_id: Optional[str], + ) -> None: """Check member event changes for any profile changes and update the database if there are. """ diff --git a/synapse/http/client.py b/synapse/http/client.py
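handle_local_profile_change in user_directory.py above now skips deactivated users as well as support users. A sketch of that guard pulled out as a helper, assuming the two datastore methods shown in the diff (the helper itself is illustrative):

async def should_index_profile(store, user_id: str) -> bool:
    # Support users exist for diagnostics; deactivated users should not
    # reappear in the directory just because their profile row was touched.
    is_support = await store.is_support_user(user_id)
    is_deactivated = await store.get_user_deactivated_status(user_id)
    return not (is_support or is_deactivated)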
index 8324632cb6..5f74ee1149 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -14,9 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import urllib +import urllib.parse from io import BytesIO from typing import ( + TYPE_CHECKING, Any, BinaryIO, Dict, @@ -31,7 +32,7 @@ from typing import ( import treq from canonicaljson import encode_canonical_json -from netaddr import IPAddress +from netaddr import IPAddress, IPSet from prometheus_client import Counter from zope.interface import implementer, provider @@ -39,6 +40,8 @@ from OpenSSL import SSL from OpenSSL.SSL import VERIFY_NONE from twisted.internet import defer, error as twisted_error, protocol, ssl from twisted.internet.interfaces import ( + IAddress, + IHostResolution, IReactorPluggableNameResolver, IResolutionReceiver, ) @@ -53,7 +56,7 @@ from twisted.web.client import ( ) from twisted.web.http import PotentialDataLoss from twisted.web.http_headers import Headers -from twisted.web.iweb import IResponse +from twisted.web.iweb import IAgent, IBodyProducer, IResponse from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri @@ -63,6 +66,9 @@ from synapse.logging.opentracing import set_tag, start_active_span, tags from synapse.util import json_decoder from synapse.util.async_helpers import timeout_deferred +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"]) @@ -84,12 +90,19 @@ QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]] QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]] -def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist): +def check_against_blacklist( + ip_address: IPAddress, ip_whitelist: Optional[IPSet], ip_blacklist: IPSet +) -> bool: """ + Compares an IP address to allowed and disallowed IP sets. + Args: - ip_address (netaddr.IPAddress) - ip_whitelist (netaddr.IPSet) - ip_blacklist (netaddr.IPSet) + ip_address: The IP address to check + ip_whitelist: Allowed IP addresses. + ip_blacklist: Disallowed IP addresses. + + Returns: + True if the IP address is in the blacklist and not in the whitelist. """ if ip_address in ip_blacklist: if ip_whitelist is None or ip_address not in ip_whitelist: @@ -112,29 +125,36 @@ def _make_scheduler(reactor): return _scheduler -class IPBlacklistingResolver: +class _IPBlacklistingResolver: """ A proxy for reactor.nameResolver which only produces non-blacklisted IP addresses, preventing DNS rebinding attacks on URL preview. """ - def __init__(self, reactor, ip_whitelist, ip_blacklist): + def __init__( + self, + reactor: IReactorPluggableNameResolver, + ip_whitelist: Optional[IPSet], + ip_blacklist: IPSet, + ): """ Args: - reactor (twisted.internet.reactor) - ip_whitelist (netaddr.IPSet) - ip_blacklist (netaddr.IPSet) + reactor: The twisted reactor. + ip_whitelist: IP addresses to allow. + ip_blacklist: IP addresses to disallow. 
""" self._reactor = reactor self._ip_whitelist = ip_whitelist self._ip_blacklist = ip_blacklist - def resolveHostName(self, recv, hostname, portNumber=0): + def resolveHostName( + self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0 + ) -> IResolutionReceiver: r = recv() - addresses = [] + addresses = [] # type: List[IAddress] - def _callback(): + def _callback() -> None: r.resolutionBegan(None) has_bad_ip = False @@ -161,15 +181,15 @@ class IPBlacklistingResolver: @provider(IResolutionReceiver) class EndpointReceiver: @staticmethod - def resolutionBegan(resolutionInProgress): + def resolutionBegan(resolutionInProgress: IHostResolution) -> None: pass @staticmethod - def addressResolved(address): + def addressResolved(address: IAddress) -> None: addresses.append(address) @staticmethod - def resolutionComplete(): + def resolutionComplete() -> None: _callback() self._reactor.nameResolver.resolveHostName( @@ -179,25 +199,64 @@ class IPBlacklistingResolver: return r +@implementer(IReactorPluggableNameResolver) +class BlacklistingReactorWrapper: + """ + A Reactor wrapper which will prevent DNS resolution to blacklisted IP + addresses, to prevent DNS rebinding. + """ + + def __init__( + self, + reactor: IReactorPluggableNameResolver, + ip_whitelist: Optional[IPSet], + ip_blacklist: IPSet, + ): + self._reactor = reactor + + # We need to use a DNS resolver which filters out blacklisted IP + # addresses, to prevent DNS rebinding. + self._nameResolver = _IPBlacklistingResolver( + self._reactor, ip_whitelist, ip_blacklist + ) + + def __getattr__(self, attr: str) -> Any: + # Passthrough to the real reactor except for the DNS resolver. + if attr == "nameResolver": + return self._nameResolver + else: + return getattr(self._reactor, attr) + + class BlacklistingAgentWrapper(Agent): """ An Agent wrapper which will prevent access to IP addresses being accessed directly (without an IP address lookup). """ - def __init__(self, agent, reactor, ip_whitelist=None, ip_blacklist=None): + def __init__( + self, + agent: IAgent, + ip_whitelist: Optional[IPSet] = None, + ip_blacklist: Optional[IPSet] = None, + ): """ Args: - agent (twisted.web.client.Agent): The Agent to wrap. - reactor (twisted.internet.reactor) - ip_whitelist (netaddr.IPSet) - ip_blacklist (netaddr.IPSet) + agent: The Agent to wrap. + ip_whitelist: IP addresses to allow. + ip_blacklist: IP addresses to disallow. """ self._agent = agent self._ip_whitelist = ip_whitelist self._ip_blacklist = ip_blacklist - def request(self, method, uri, headers=None, bodyProducer=None): + def request( + self, + method: bytes, + uri: bytes, + headers: Optional[Headers] = None, + bodyProducer: Optional[IBodyProducer] = None, + ) -> defer.Deferred: h = urllib.parse.urlparse(uri.decode("ascii")) try: @@ -226,23 +285,23 @@ class SimpleHttpClient: def __init__( self, - hs, - treq_args={}, - ip_whitelist=None, - ip_blacklist=None, - http_proxy=None, - https_proxy=None, + hs: "HomeServer", + treq_args: Dict[str, Any] = {}, + ip_whitelist: Optional[IPSet] = None, + ip_blacklist: Optional[IPSet] = None, + http_proxy: Optional[bytes] = None, + https_proxy: Optional[bytes] = None, ): """ Args: - hs (synapse.server.HomeServer) - treq_args (dict): Extra keyword arguments to be given to treq.request. - ip_blacklist (netaddr.IPSet): The IP addresses that are blacklisted that + hs + treq_args: Extra keyword arguments to be given to treq.request. + ip_blacklist: The IP addresses that are blacklisted that we may not request. 
- ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can + ip_whitelist: The whitelisted IP addresses, that we can request if it were otherwise caught in a blacklist. - http_proxy (bytes): proxy server to use for http connections. host[:port] - https_proxy (bytes): proxy server to use for https connections. host[:port] + http_proxy: proxy server to use for http connections. host[:port] + https_proxy: proxy server to use for https connections. host[:port] """ self.hs = hs @@ -262,22 +321,11 @@ class SimpleHttpClient: self.user_agent = self.user_agent.encode("ascii") if self._ip_blacklist: - real_reactor = hs.get_reactor() # If we have an IP blacklist, we need to use a DNS resolver which # filters out blacklisted IP addresses, to prevent DNS rebinding. - nameResolver = IPBlacklistingResolver( - real_reactor, self._ip_whitelist, self._ip_blacklist + self.reactor = BlacklistingReactorWrapper( + hs.get_reactor(), self._ip_whitelist, self._ip_blacklist ) - - @implementer(IReactorPluggableNameResolver) - class Reactor: - def __getattr__(_self, attr): - if attr == "nameResolver": - return nameResolver - else: - return getattr(real_reactor, attr) - - self.reactor = Reactor() else: self.reactor = hs.get_reactor() @@ -293,6 +341,7 @@ class SimpleHttpClient: self.agent = ProxyAgent( self.reactor, + hs.get_reactor(), connectTimeout=15, contextFactory=self.hs.get_http_client_context_factory(), pool=pool, @@ -306,7 +355,6 @@ class SimpleHttpClient: # by the DNS resolution. self.agent = BlacklistingAgentWrapper( self.agent, - self.reactor, ip_whitelist=self._ip_whitelist, ip_blacklist=self._ip_blacklist, ) @@ -359,7 +407,7 @@ class SimpleHttpClient: agent=self.agent, data=body_producer, headers=headers, - **self._extra_treq_args + **self._extra_treq_args, ) # type: defer.Deferred # we use our own timeout mechanism rather than treq's as a workaround @@ -397,7 +445,7 @@ class SimpleHttpClient: async def post_urlencoded_get_json( self, uri: str, - args: Mapping[str, Union[str, List[str]]] = {}, + args: Optional[Mapping[str, Union[str, List[str]]]] = None, headers: Optional[RawHeaders] = None, ) -> Any: """ @@ -422,9 +470,7 @@ class SimpleHttpClient: # TODO: Do we ever want to log message contents? logger.debug("post_urlencoded_get_json args: %s", args) - query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True).encode( - "utf8" - ) + query_bytes = encode_query_args(args) actual_headers = { b"Content-Type": [b"application/x-www-form-urlencoded"], @@ -432,7 +478,7 @@ class SimpleHttpClient: b"Accept": [b"application/json"], } if headers: - actual_headers.update(headers) + actual_headers.update(headers) # type: ignore response = await self.request( "POST", uri, headers=Headers(actual_headers), data=query_bytes @@ -479,7 +525,7 @@ class SimpleHttpClient: b"Accept": [b"application/json"], } if headers: - actual_headers.update(headers) + actual_headers.update(headers) # type: ignore response = await self.request( "POST", uri, headers=Headers(actual_headers), data=json_str @@ -495,7 +541,10 @@ class SimpleHttpClient: ) async def get_json( - self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None, + self, + uri: str, + args: Optional[QueryParams] = None, + headers: Optional[RawHeaders] = None, ) -> Any: """Gets some json from the given URI. 
@@ -516,7 +565,7 @@ class SimpleHttpClient: """ actual_headers = {b"Accept": [b"application/json"]} if headers: - actual_headers.update(headers) + actual_headers.update(headers) # type: ignore body = await self.get_raw(uri, args, headers=headers) return json_decoder.decode(body.decode("utf-8")) @@ -525,7 +574,7 @@ class SimpleHttpClient: self, uri: str, json_body: Any, - args: QueryParams = {}, + args: Optional[QueryParams] = None, headers: RawHeaders = None, ) -> Any: """Puts some json to the given URI. @@ -546,9 +595,9 @@ class SimpleHttpClient: ValueError: if the response was not JSON """ - if len(args): - query_bytes = urllib.parse.urlencode(args, True) - uri = "%s?%s" % (uri, query_bytes) + if args: + query_str = urllib.parse.urlencode(args, True) + uri = "%s?%s" % (uri, query_str) json_str = encode_canonical_json(json_body) @@ -558,7 +607,7 @@ class SimpleHttpClient: b"Accept": [b"application/json"], } if headers: - actual_headers.update(headers) + actual_headers.update(headers) # type: ignore response = await self.request( "PUT", uri, headers=Headers(actual_headers), data=json_str @@ -574,7 +623,10 @@ class SimpleHttpClient: ) async def get_raw( - self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None + self, + uri: str, + args: Optional[QueryParams] = None, + headers: Optional[RawHeaders] = None, ) -> bytes: """Gets raw text from the given URI. @@ -592,13 +644,13 @@ class SimpleHttpClient: HttpResponseException on a non-2xx HTTP response. """ - if len(args): - query_bytes = urllib.parse.urlencode(args, True) - uri = "%s?%s" % (uri, query_bytes) + if args: + query_str = urllib.parse.urlencode(args, True) + uri = "%s?%s" % (uri, query_str) actual_headers = {b"User-Agent": [self.user_agent]} if headers: - actual_headers.update(headers) + actual_headers.update(headers) # type: ignore response = await self.request("GET", uri, headers=Headers(actual_headers)) @@ -641,7 +693,7 @@ class SimpleHttpClient: actual_headers = {b"User-Agent": [self.user_agent]} if headers: - actual_headers.update(headers) + actual_headers.update(headers) # type: ignore response = await self.request("GET", url, headers=Headers(actual_headers)) @@ -649,12 +701,13 @@ class SimpleHttpClient: if ( b"Content-Length" in resp_headers + and max_size and int(resp_headers[b"Content-Length"][0]) > max_size ): - logger.warning("Requested URL is too large > %r bytes" % (self.max_size,)) + logger.warning("Requested URL is too large > %r bytes" % (max_size,)) raise SynapseError( 502, - "Requested file is too large > %r bytes" % (self.max_size,), + "Requested file is too large > %r bytes" % (max_size,), Codes.TOO_LARGE, ) @@ -668,11 +721,14 @@ class SimpleHttpClient: try: length = await make_deferred_yieldable( - _readBodyToFile(response, output_stream, max_size) + read_body_with_max_size(response, output_stream, max_size) + ) + except BodyExceededMaxSize: + SynapseError( + 502, + "Requested file is too large > %r bytes" % (max_size,), + Codes.TOO_LARGE, ) - except SynapseError: - # This can happen e.g. because the body is too large. - raise except Exception as e: raise SynapseError(502, ("Failed to download remote body: %s" % e)) from e @@ -696,32 +752,28 @@ def _timeout_to_request_timed_out_error(f: Failure): return f -# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. -# The two should be factored out. 
+class BodyExceededMaxSize(Exception): + """The maximum allowed size of the HTTP body was exceeded.""" -class _ReadBodyToFileProtocol(protocol.Protocol): - def __init__(self, stream, deferred, max_size): +class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): + def __init__( + self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int] + ): self.stream = stream self.deferred = deferred self.length = 0 self.max_size = max_size - def dataReceived(self, data): + def dataReceived(self, data: bytes) -> None: self.stream.write(data) self.length += len(data) if self.max_size is not None and self.length >= self.max_size: - self.deferred.errback( - SynapseError( - 502, - "Requested file is too large > %r bytes" % (self.max_size,), - Codes.TOO_LARGE, - ) - ) + self.deferred.errback(BodyExceededMaxSize()) self.deferred = defer.Deferred() self.transport.loseConnection() - def connectionLost(self, reason): + def connectionLost(self, reason: Failure) -> None: if reason.check(ResponseDone): self.deferred.callback(self.length) elif reason.check(PotentialDataLoss): @@ -732,35 +784,51 @@ class _ReadBodyToFileProtocol(protocol.Protocol): self.deferred.errback(reason) -# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. -# The two should be factored out. +def read_body_with_max_size( + response: IResponse, stream: BinaryIO, max_size: Optional[int] +) -> defer.Deferred: + """ + Read a HTTP response body to a file-object. Optionally enforcing a maximum file size. + + If the maximum file size is reached, the returned Deferred will resolve to a + Failure with a BodyExceededMaxSize exception. + + Args: + response: The HTTP response to read from. + stream: The file-object to write to. + max_size: The maximum file size to allow. + Returns: + A Deferred which resolves to the length of the read body. + """ -def _readBodyToFile(response, stream, max_size): d = defer.Deferred() - response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size)) + response.deliverBody(_ReadBodyWithMaxSizeProtocol(stream, d, max_size)) return d -def encode_urlencode_args(args): - return {k: encode_urlencode_arg(v) for k, v in args.items()} +def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> bytes: + """ + Encodes a map of query arguments to bytes which can be appended to a URL. + + Args: + args: The query arguments, a mapping of string to string or list of strings. + Returns: + The query arguments encoded as bytes. + """ + if args is None: + return b"" -def encode_urlencode_arg(arg): - if isinstance(arg, str): - return arg.encode("utf-8") - elif isinstance(arg, list): - return [encode_urlencode_arg(i) for i in arg] - else: - return arg + encoded_args = {} + for k, vs in args.items(): + if isinstance(vs, str): + vs = [vs] + encoded_args[k] = [v.encode("utf8") for v in vs] + query_str = urllib.parse.urlencode(encoded_args, True) -def _print_ex(e): - if hasattr(e, "reasons") and e.reasons: - for ex in e.reasons: - _print_ex(ex) - else: - logger.exception(e) + return query_str.encode("utf8") class InsecureInterceptableContextFactory(ssl.ContextFactory): diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
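For illustration, this is how a caller is expected to use the read_body_with_max_size / BodyExceededMaxSize pair introduced in the synapse/http/client.py hunks above. The sketch is not part of the commit: the function name and the 1 MiB cap are invented, while the imports mirror the ones the well-known resolver adds further down.

from io import BytesIO

from twisted.web.iweb import IResponse

from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size


async def read_capped_body(response: IResponse, max_size: int = 1024 * 1024) -> bytes:
    """Stream a response body into memory, giving up once max_size is exceeded."""
    buffer = BytesIO()
    try:
        # read_body_with_max_size returns a Deferred resolving to the length read;
        # if the cap is hit it errbacks with BodyExceededMaxSize and the protocol
        # drops the connection.
        await read_body_with_max_size(response, buffer, max_size)
    except BodyExceededMaxSize:
        raise ValueError("Response body exceeded %d bytes" % (max_size,))
    return buffer.getvalue()

Note that the caller has to re-raise (or convert) BodyExceededMaxSize itself; otherwise an oversize body is silently treated as a success.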
index 83d6196d4a..3b756a7dc2 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py
@@ -12,21 +12,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging -import urllib -from typing import List +import urllib.parse +from typing import List, Optional -from netaddr import AddrFormatError, IPAddress +from netaddr import AddrFormatError, IPAddress, IPSet from zope.interface import implementer from twisted.internet import defer from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS -from twisted.internet.interfaces import IStreamClientEndpoint -from twisted.web.client import Agent, HTTPConnectionPool +from twisted.internet.interfaces import ( + IProtocolFactory, + IReactorCore, + IStreamClientEndpoint, +) +from twisted.web.client import URI, Agent, HTTPConnectionPool from twisted.web.http_headers import Headers -from twisted.web.iweb import IAgent, IAgentEndpointFactory +from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer +from synapse.crypto.context_factory import FederationPolicyForHTTPS +from synapse.http.client import BlacklistingAgentWrapper from synapse.http.federation.srv_resolver import Server, SrvResolver from synapse.http.federation.well_known_resolver import WellKnownResolver from synapse.logging.context import make_deferred_yieldable, run_in_background @@ -44,30 +49,31 @@ class MatrixFederationAgent: Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.) Args: - reactor (IReactor): twisted reactor to use for underlying requests + reactor: twisted reactor to use for underlying requests - tls_client_options_factory (FederationPolicyForHTTPS|None): + tls_client_options_factory: factory to use for fetching client tls options, or none to disable TLS. - user_agent (bytes): + user_agent: The user agent header to use for federation requests. - _srv_resolver (SrvResolver|None): - SRVResolver impl to use for looking up SRV records. None to use a default - implementation. + _srv_resolver: + SrvResolver implementation to use for looking up SRV records. None + to use a default implementation. - _well_known_resolver (WellKnownResolver|None): + _well_known_resolver: WellKnownResolver to use to perform well-known lookups. None to use a default implementation. """ def __init__( self, - reactor, - tls_client_options_factory, - user_agent, - _srv_resolver=None, - _well_known_resolver=None, + reactor: IReactorCore, + tls_client_options_factory: Optional[FederationPolicyForHTTPS], + user_agent: bytes, + ip_blacklist: IPSet, + _srv_resolver: Optional[SrvResolver] = None, + _well_known_resolver: Optional[WellKnownResolver] = None, ): self._reactor = reactor self._clock = Clock(reactor) @@ -86,12 +92,18 @@ class MatrixFederationAgent: self.user_agent = user_agent if _well_known_resolver is None: + # Note that the name resolver has already been wrapped in a + # IPBlacklistingResolver by MatrixFederationHttpClient. 
_well_known_resolver = WellKnownResolver( self._reactor, - agent=Agent( + agent=BlacklistingAgentWrapper( + Agent( + self._reactor, + pool=self._pool, + contextFactory=tls_client_options_factory, + ), self._reactor, - pool=self._pool, - contextFactory=tls_client_options_factory, + ip_blacklist=ip_blacklist, ), user_agent=self.user_agent, ) @@ -99,15 +111,20 @@ class MatrixFederationAgent: self._well_known_resolver = _well_known_resolver @defer.inlineCallbacks - def request(self, method, uri, headers=None, bodyProducer=None): + def request( + self, + method: bytes, + uri: bytes, + headers: Optional[Headers] = None, + bodyProducer: Optional[IBodyProducer] = None, + ) -> defer.Deferred: """ Args: - method (bytes): HTTP method: GET/POST/etc - uri (bytes): Absolute URI to be retrieved - headers (twisted.web.http_headers.Headers|None): - HTTP headers to send with the request, or None to - send no extra headers. - bodyProducer (twisted.web.iweb.IBodyProducer|None): + method: HTTP method: GET/POST/etc + uri: Absolute URI to be retrieved + headers: + HTTP headers to send with the request, or None to send no extra headers. + bodyProducer: An object which can generate bytes to make up the body of this request (for example, the properly encoded contents of a file for a file upload). Or None if the request is to have @@ -123,6 +140,9 @@ class MatrixFederationAgent: # explicit port. parsed_uri = urllib.parse.urlparse(uri) + # There must be a valid hostname. + assert parsed_uri.hostname + # If this is a matrix:// URI check if the server has delegated matrix # traffic using well-known delegation. # @@ -179,7 +199,12 @@ class MatrixHostnameEndpointFactory: """Factory for MatrixHostnameEndpoint for parsing to an Agent. """ - def __init__(self, reactor, tls_client_options_factory, srv_resolver): + def __init__( + self, + reactor: IReactorCore, + tls_client_options_factory: Optional[FederationPolicyForHTTPS], + srv_resolver: Optional[SrvResolver], + ): self._reactor = reactor self._tls_client_options_factory = tls_client_options_factory @@ -203,15 +228,20 @@ class MatrixHostnameEndpoint: resolution (i.e. via SRV). Does not check for well-known delegation. Args: - reactor (IReactor) - tls_client_options_factory (ClientTLSOptionsFactory|None): + reactor: twisted reactor to use for underlying requests + tls_client_options_factory: factory to use for fetching client tls options, or none to disable TLS. - srv_resolver (SrvResolver): The SRV resolver to use - parsed_uri (twisted.web.client.URI): The parsed URI that we're wanting - to connect to. + srv_resolver: The SRV resolver to use + parsed_uri: The parsed URI that we're wanting to connect to. 
""" - def __init__(self, reactor, tls_client_options_factory, srv_resolver, parsed_uri): + def __init__( + self, + reactor: IReactorCore, + tls_client_options_factory: Optional[FederationPolicyForHTTPS], + srv_resolver: SrvResolver, + parsed_uri: URI, + ): self._reactor = reactor self._parsed_uri = parsed_uri @@ -231,13 +261,13 @@ class MatrixHostnameEndpoint: self._srv_resolver = srv_resolver - def connect(self, protocol_factory): + def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred: """Implements IStreamClientEndpoint interface """ return run_in_background(self._do_connect, protocol_factory) - async def _do_connect(self, protocol_factory): + async def _do_connect(self, protocol_factory: IProtocolFactory) -> None: first_exception = None server_list = await self._resolve_server() @@ -303,20 +333,20 @@ class MatrixHostnameEndpoint: return [Server(host, 8448)] -def _is_ip_literal(host): +def _is_ip_literal(host: bytes) -> bool: """Test if the given host name is either an IPv4 or IPv6 literal. Args: - host (bytes) + host: The host name to check Returns: - bool + True if the hostname is an IP address literal. """ - host = host.decode("ascii") + host_str = host.decode("ascii") try: - IPAddress(host) + IPAddress(host_str) return True except AddrFormatError: return False diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index a306faa267..b3b6dbcab0 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py
@@ -12,20 +12,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging import random import time +from io import BytesIO from typing import Callable, Dict, Optional, Tuple import attr from twisted.internet import defer -from twisted.web.client import RedirectAgent, readBody +from twisted.internet.interfaces import IReactorTime +from twisted.web.client import RedirectAgent from twisted.web.http import stringToDatetime from twisted.web.http_headers import Headers -from twisted.web.iweb import IResponse +from twisted.web.iweb import IAgent, IResponse +from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size from synapse.logging.context import make_deferred_yieldable from synapse.util import Clock, json_decoder from synapse.util.caches.ttlcache import TTLCache @@ -53,6 +55,9 @@ WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600 # lower bound for .well-known cache period WELL_KNOWN_MIN_CACHE_PERIOD = 5 * 60 +# The maximum size (in bytes) to allow a well-known file to be. +WELL_KNOWN_MAX_SIZE = 50 * 1024 # 50 KiB + # Attempt to refetch a cached well-known N% of the TTL before it expires. # e.g. if set to 0.2 and we have a cached entry with a TTL of 5mins, then # we'll start trying to refetch 1 minute before it expires. @@ -81,11 +86,11 @@ class WellKnownResolver: def __init__( self, - reactor, - agent, - user_agent, - well_known_cache=None, - had_well_known_cache=None, + reactor: IReactorTime, + agent: IAgent, + user_agent: bytes, + well_known_cache: Optional[TTLCache] = None, + had_well_known_cache: Optional[TTLCache] = None, ): self._reactor = reactor self._clock = Clock(reactor) @@ -127,7 +132,7 @@ class WellKnownResolver: with Measure(self._clock, "get_well_known"): result, cache_period = await self._fetch_well_known( server_name - ) # type: Tuple[Optional[bytes], float] + ) # type: Optional[bytes], float except _FetchWellKnownFailure as e: if prev_result and e.temporary: @@ -172,7 +177,7 @@ class WellKnownResolver: had_valid_well_known = self._had_valid_well_known_cache.get(server_name, False) # We do this in two steps to differentiate between possibly transient - # errors (e.g. can't connect to host, 503 response) and more permenant + # errors (e.g. can't connect to host, 503 response) and more permanent # errors (such as getting a 404 response). response, body = await self._make_well_known_request( server_name, retry=had_valid_well_known @@ -229,6 +234,9 @@ class WellKnownResolver: server_name: name of the server, from the requested url retry: Whether to retry the request if it fails. + Raises: + _FetchWellKnownFailure if we fail to lookup a result + Returns: Returns the response object and body. Response may be a non-200 response. """ @@ -250,7 +258,11 @@ class WellKnownResolver: b"GET", uri, headers=Headers(headers) ) ) - body = await make_deferred_yieldable(readBody(response)) + body_stream = BytesIO() + await make_deferred_yieldable( + read_body_with_max_size(response, body_stream, WELL_KNOWN_MAX_SIZE) + ) + body = body_stream.getvalue() if 500 <= response.code < 600: raise Exception("Non-200 response %s" % (response.code,)) @@ -259,6 +271,15 @@ class WellKnownResolver: except defer.CancelledError: # Bail if we've been cancelled raise + except BodyExceededMaxSize: + # If the well-known file was too large, do not keep attempting + # to download it, but consider it a temporary error. 
+ logger.warning( + "Requested .well-known file for %s is too large > %r bytes", + server_name.decode("ascii"), + WELL_KNOWN_MAX_SIZE, + ) + raise _FetchWellKnownFailure(temporary=True) except Exception as e: if not retry or i >= WELL_KNOWN_RETRY_ATTEMPTS: logger.info("Error fetching %s: %s", uri_str, e) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
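As a worked example of the caching arithmetic described by the comments in the well_known_resolver.py hunks above. The two constants come from the hunk; the helper names and the 0.2 fraction (which the comment only gives as an example value) are illustrative and not part of the commit.

WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600  # upper bound: 48 hours
WELL_KNOWN_MIN_CACHE_PERIOD = 5 * 60     # lower bound: 5 minutes
REFETCH_FRACTION = 0.2                   # "N%" of the TTL, per the comment's example


def clamp_cache_period(period: float) -> float:
    # Keep the advertised lifetime of a .well-known result within the bounds.
    return max(WELL_KNOWN_MIN_CACHE_PERIOD, min(period, WELL_KNOWN_MAX_CACHE_PERIOD))


def refetch_after(period: float) -> float:
    # Start trying to refetch 20% of the TTL before expiry: a 5 minute TTL is
    # refreshed from 4 minutes onwards, i.e. 1 minute before it expires.
    return period * (1 - REFETCH_FRACTION)


assert clamp_cache_period(60) == 300                 # too short, raised to 5 minutes
assert clamp_cache_period(7 * 24 * 3600) == 172800   # too long, capped at 48 hours
assert refetch_after(300) == 240.0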
index c23a4d7c0c..b261e078c4 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py
@@ -17,23 +17,22 @@ import cgi import logging import random import sys -import urllib +import urllib.parse from io import BytesIO +from typing import Callable, Dict, List, Optional, Tuple, Union import attr import treq from canonicaljson import encode_canonical_json from prometheus_client import Counter from signedjson.sign import sign_json -from zope.interface import implementer -from twisted.internet import defer, protocol +from twisted.internet import defer from twisted.internet.error import DNSLookupError -from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime +from twisted.internet.interfaces import IReactorTime from twisted.internet.task import _EPSILON, Cooperator -from twisted.web._newclient import ResponseDone from twisted.web.http_headers import Headers -from twisted.web.iweb import IResponse +from twisted.web.iweb import IBodyProducer, IResponse import synapse.metrics import synapse.util.retryutils @@ -45,7 +44,13 @@ from synapse.api.errors import ( SynapseError, ) from synapse.http import QuieterFileBodyProducer -from synapse.http.client import BlacklistingAgentWrapper, IPBlacklistingResolver +from synapse.http.client import ( + BlacklistingAgentWrapper, + BlacklistingReactorWrapper, + BodyExceededMaxSize, + encode_query_args, + read_body_with_max_size, +) from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import ( @@ -54,6 +59,7 @@ from synapse.logging.opentracing import ( start_active_span, tags, ) +from synapse.types import JsonDict from synapse.util import json_decoder from synapse.util.async_helpers import timeout_deferred from synapse.util.metrics import Measure @@ -76,47 +82,44 @@ MAXINT = sys.maxsize _next_id = 1 +QueryArgs = Dict[str, Union[str, List[str]]] + + @attr.s(slots=True, frozen=True) class MatrixFederationRequest: - method = attr.ib() + method = attr.ib(type=str) """HTTP method - :type: str """ - path = attr.ib() + path = attr.ib(type=str) """HTTP path - :type: str """ - destination = attr.ib() + destination = attr.ib(type=str) """The remote server to send the HTTP request to. - :type: str""" + """ - json = attr.ib(default=None) + json = attr.ib(default=None, type=Optional[JsonDict]) """JSON to send in the body. - :type: dict|None """ - json_callback = attr.ib(default=None) + json_callback = attr.ib(default=None, type=Optional[Callable[[], JsonDict]]) """A callback to generate the JSON. - :type: func|None """ - query = attr.ib(default=None) + query = attr.ib(default=None, type=Optional[dict]) """Query arguments. 
- :type: dict|None """ - txn_id = attr.ib(default=None) + txn_id = attr.ib(default=None, type=Optional[str]) """Unique ID for this request (for logging) - :type: str|None """ uri = attr.ib(init=False, type=bytes) """The URI of this request """ - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: global _next_id txn_id = "%s-O-%s" % (self.method, _next_id) _next_id = (_next_id + 1) % (MAXINT - 1) @@ -136,7 +139,7 @@ class MatrixFederationRequest: ) object.__setattr__(self, "uri", uri) - def get_json(self): + def get_json(self) -> Optional[JsonDict]: if self.json_callback: return self.json_callback() return self.json @@ -148,7 +151,7 @@ async def _handle_json_response( request: MatrixFederationRequest, response: IResponse, start_ms: int, -): +) -> JsonDict: """ Reads the JSON body of a response, with a timeout @@ -160,7 +163,7 @@ async def _handle_json_response( start_ms: Timestamp when request was made Returns: - dict: parsed JSON response + The parsed JSON response """ try: check_content_type_is_json(response.headers) @@ -220,39 +223,28 @@ class MatrixFederationHttpClient: self.signing_key = hs.signing_key self.server_name = hs.hostname - real_reactor = hs.get_reactor() - # We need to use a DNS resolver which filters out blacklisted IP # addresses, to prevent DNS rebinding. - nameResolver = IPBlacklistingResolver( - real_reactor, None, hs.config.federation_ip_range_blacklist + self.reactor = BlacklistingReactorWrapper( + hs.get_reactor(), None, hs.config.federation_ip_range_blacklist ) - @implementer(IReactorPluggableNameResolver) - class Reactor: - def __getattr__(_self, attr): - if attr == "nameResolver": - return nameResolver - else: - return getattr(real_reactor, attr) - - self.reactor = Reactor() - user_agent = hs.version_string if hs.config.user_agent_suffix: user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix) user_agent = user_agent.encode("ascii") self.agent = MatrixFederationAgent( - self.reactor, tls_client_options_factory, user_agent + self.reactor, + tls_client_options_factory, + user_agent, + hs.config.federation_ip_range_blacklist, ) # Use a BlacklistingAgentWrapper to prevent circumventing the IP # blacklist via IP literals in server names self.agent = BlacklistingAgentWrapper( - self.agent, - self.reactor, - ip_blacklist=hs.config.federation_ip_range_blacklist, + self.agent, ip_blacklist=hs.config.federation_ip_range_blacklist, ) self.clock = hs.get_clock() @@ -266,27 +258,29 @@ class MatrixFederationHttpClient: self._cooperator = Cooperator(scheduler=schedule) async def _send_request_with_optional_trailing_slash( - self, request, try_trailing_slash_on_400=False, **send_request_args - ): + self, + request: MatrixFederationRequest, + try_trailing_slash_on_400: bool = False, + **send_request_args + ) -> IResponse: """Wrapper for _send_request which can optionally retry the request upon receiving a combination of a 400 HTTP response code and a 'M_UNRECOGNIZED' errcode. This is a workaround for Synapse <= v0.99.3 due to #3622. Args: - request (MatrixFederationRequest): details of request to be sent - try_trailing_slash_on_400 (bool): Whether on receiving a 400 + request: details of request to be sent + try_trailing_slash_on_400: Whether on receiving a 400 'M_UNRECOGNIZED' from the server to retry the request with a trailing slash appended to the request path. - send_request_args (Dict): A dictionary of arguments to pass to - `_send_request()`. + send_request_args: A dictionary of arguments to pass to `_send_request()`. 
Raises: HttpResponseException: If we get an HTTP response code >= 300 (except 429). Returns: - Dict: Parsed JSON response body. + Parsed JSON response body. """ try: response = await self._send_request(request, **send_request_args) @@ -313,24 +307,26 @@ class MatrixFederationHttpClient: async def _send_request( self, - request, - retry_on_dns_fail=True, - timeout=None, - long_retries=False, - ignore_backoff=False, - backoff_on_404=False, - ): + request: MatrixFederationRequest, + retry_on_dns_fail: bool = True, + timeout: Optional[int] = None, + long_retries: bool = False, + ignore_backoff: bool = False, + backoff_on_404: bool = False, + ) -> IResponse: """ Sends a request to the given server. Args: - request (MatrixFederationRequest): details of request to be sent + request: details of request to be sent + + retry_on_dns_fail: true if the request should be retied on DNS failures - timeout (int|None): number of milliseconds to wait for the response headers + timeout: number of milliseconds to wait for the response headers (including connecting to the server), *for each attempt*. 60s by default. - long_retries (bool): whether to use the long retry algorithm. + long_retries: whether to use the long retry algorithm. The regular retry algorithm makes 4 attempts, with intervals [0.5s, 1s, 2s]. @@ -346,14 +342,13 @@ class MatrixFederationHttpClient: NB: the long retry algorithm takes over 20 minutes to complete, with a default timeout of 60s! - ignore_backoff (bool): true to ignore the historical backoff data + ignore_backoff: true to ignore the historical backoff data and try the request anyway. - backoff_on_404 (bool): Back off if we get a 404 + backoff_on_404: Back off if we get a 404 Returns: - twisted.web.client.Response: resolves with the HTTP - response object on success. + Resolves with the HTTP response object on success. Raises: HttpResponseException: If we get an HTTP response code >= 300 @@ -404,7 +399,7 @@ class MatrixFederationHttpClient: ) # Inject the span into the headers - headers_dict = {} + headers_dict = {} # type: Dict[bytes, List[bytes]] inject_active_span_byte_dict(headers_dict, request.destination) headers_dict[b"User-Agent"] = [self.version_string_bytes] @@ -435,7 +430,7 @@ class MatrixFederationHttpClient: data = encode_canonical_json(json) producer = QuieterFileBodyProducer( BytesIO(data), cooperator=self._cooperator - ) + ) # type: Optional[IBodyProducer] else: producer = None auth_headers = self.build_auth_headers( @@ -524,14 +519,16 @@ class MatrixFederationHttpClient: ) body = None - e = HttpResponseException(response.code, response_phrase, body) + exc = HttpResponseException( + response.code, response_phrase, body + ) # Retry if the error is a 429 (Too Many Requests), # otherwise just raise a standard HttpResponseException if response.code == 429: - raise RequestSendFailed(e, can_retry=True) from e + raise RequestSendFailed(exc, can_retry=True) from exc else: - raise e + raise exc break except RequestSendFailed as e: @@ -582,22 +579,27 @@ class MatrixFederationHttpClient: return response def build_auth_headers( - self, destination, method, url_bytes, content=None, destination_is=None - ): + self, + destination: Optional[bytes], + method: bytes, + url_bytes: bytes, + content: Optional[JsonDict] = None, + destination_is: Optional[bytes] = None, + ) -> List[bytes]: """ Builds the Authorization headers for a federation request Args: - destination (bytes|None): The desination homeserver of the request. + destination: The destination homeserver of the request. 
May be None if the destination is an identity server, in which case destination_is must be non-None. - method (bytes): The HTTP method of the request - url_bytes (bytes): The URI path of the request - content (object): The body of the request - destination_is (bytes): As 'destination', but if the destination is an + method: The HTTP method of the request + url_bytes: The URI path of the request + content: The body of the request + destination_is: As 'destination', but if the destination is an identity server Returns: - list[bytes]: a list of headers to be added as "Authorization:" headers + A list of headers to be added as "Authorization:" headers """ request = { "method": method.decode("ascii"), @@ -629,33 +631,32 @@ class MatrixFederationHttpClient: async def put_json( self, - destination, - path, - args={}, - data={}, - json_data_callback=None, - long_retries=False, - timeout=None, - ignore_backoff=False, - backoff_on_404=False, - try_trailing_slash_on_400=False, - ): - """ Sends the specifed json data using PUT + destination: str, + path: str, + args: Optional[QueryArgs] = None, + data: Optional[JsonDict] = None, + json_data_callback: Optional[Callable[[], JsonDict]] = None, + long_retries: bool = False, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + backoff_on_404: bool = False, + try_trailing_slash_on_400: bool = False, + ) -> Union[JsonDict, list]: + """ Sends the specified json data using PUT Args: - destination (str): The remote server to send the HTTP request - to. - path (str): The HTTP path. - args (dict): query params - data (dict): A dict containing the data that will be used as + destination: The remote server to send the HTTP request to. + path: The HTTP path. + args: query params + data: A dict containing the data that will be used as the request body. This will be encoded as JSON. - json_data_callback (callable): A callable returning the dict to + json_data_callback: A callable returning the dict to use as the request body. - long_retries (bool): whether to use the long retry algorithm. See + long_retries: whether to use the long retry algorithm. See docs on _send_request for details. - timeout (int|None): number of milliseconds to wait for the response. + timeout: number of milliseconds to wait for the response. self._default_timeout (60s) by default. Note that we may make several attempts to send the request; this @@ -663,19 +664,19 @@ class MatrixFederationHttpClient: *each* attempt (including connection time) as well as the time spent reading the response body after a 200 response. - ignore_backoff (bool): true to ignore the historical backoff data + ignore_backoff: true to ignore the historical backoff data and try the request anyway. - backoff_on_404 (bool): True if we should count a 404 response as + backoff_on_404: True if we should count a 404 response as a failure of the server (and should therefore back off future requests). - try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED + try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED response we should try appending a trailing slash to the end of the request. Workaround for #3622 in Synapse <= v0.99.3. This will be attempted before backing off if backing off has been enabled. Returns: - dict|list: Succeeds when we get a 2xx HTTP response. The + Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. 
Raises: @@ -721,29 +722,28 @@ class MatrixFederationHttpClient: async def post_json( self, - destination, - path, - data={}, - long_retries=False, - timeout=None, - ignore_backoff=False, - args={}, - ): - """ Sends the specifed json data using POST + destination: str, + path: str, + data: Optional[JsonDict] = None, + long_retries: bool = False, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + args: Optional[QueryArgs] = None, + ) -> Union[JsonDict, list]: + """ Sends the specified json data using POST Args: - destination (str): The remote server to send the HTTP request - to. + destination: The remote server to send the HTTP request to. - path (str): The HTTP path. + path: The HTTP path. - data (dict): A dict containing the data that will be used as + data: A dict containing the data that will be used as the request body. This will be encoded as JSON. - long_retries (bool): whether to use the long retry algorithm. See + long_retries: whether to use the long retry algorithm. See docs on _send_request for details. - timeout (int|None): number of milliseconds to wait for the response. + timeout: number of milliseconds to wait for the response. self._default_timeout (60s) by default. Note that we may make several attempts to send the request; this @@ -751,10 +751,10 @@ class MatrixFederationHttpClient: *each* attempt (including connection time) as well as the time spent reading the response body after a 200 response. - ignore_backoff (bool): true to ignore the historical backoff data and + ignore_backoff: true to ignore the historical backoff data and try the request anyway. - args (dict): query params + args: query params Returns: dict|list: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. @@ -795,26 +795,25 @@ class MatrixFederationHttpClient: async def get_json( self, - destination, - path, - args=None, - retry_on_dns_fail=True, - timeout=None, - ignore_backoff=False, - try_trailing_slash_on_400=False, - ): + destination: str, + path: str, + args: Optional[QueryArgs] = None, + retry_on_dns_fail: bool = True, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + try_trailing_slash_on_400: bool = False, + ) -> Union[JsonDict, list]: """ GETs some json from the given host homeserver and path Args: - destination (str): The remote server to send the HTTP request - to. + destination: The remote server to send the HTTP request to. - path (str): The HTTP path. + path: The HTTP path. - args (dict|None): A dictionary used to create query strings, defaults to + args: A dictionary used to create query strings, defaults to None. - timeout (int|None): number of milliseconds to wait for the response. + timeout: number of milliseconds to wait for the response. self._default_timeout (60s) by default. Note that we may make several attempts to send the request; this @@ -822,14 +821,14 @@ class MatrixFederationHttpClient: *each* attempt (including connection time) as well as the time spent reading the response body after a 200 response. - ignore_backoff (bool): true to ignore the historical backoff data + ignore_backoff: true to ignore the historical backoff data and try the request anyway. - try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED + try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED response we should try appending a trailing slash to the end of the request. Workaround for #3622 in Synapse <= v0.99.3. Returns: - dict|list: Succeeds when we get a 2xx HTTP response. The + Succeeds when we get a 2xx HTTP response. 
The result will be the decoded JSON body. Raises: @@ -870,24 +869,23 @@ class MatrixFederationHttpClient: async def delete_json( self, - destination, - path, - long_retries=False, - timeout=None, - ignore_backoff=False, - args={}, - ): + destination: str, + path: str, + long_retries: bool = False, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + args: Optional[QueryArgs] = None, + ) -> Union[JsonDict, list]: """Send a DELETE request to the remote expecting some json response Args: - destination (str): The remote server to send the HTTP request - to. - path (str): The HTTP path. + destination: The remote server to send the HTTP request to. + path: The HTTP path. - long_retries (bool): whether to use the long retry algorithm. See + long_retries: whether to use the long retry algorithm. See docs on _send_request for details. - timeout (int|None): number of milliseconds to wait for the response. + timeout: number of milliseconds to wait for the response. self._default_timeout (60s) by default. Note that we may make several attempts to send the request; this @@ -895,12 +893,12 @@ class MatrixFederationHttpClient: *each* attempt (including connection time) as well as the time spent reading the response body after a 200 response. - ignore_backoff (bool): true to ignore the historical backoff data and + ignore_backoff: true to ignore the historical backoff data and try the request anyway. - args (dict): query params + args: query params Returns: - dict|list: Succeeds when we get a 2xx HTTP response. The + Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Raises: @@ -938,25 +936,25 @@ class MatrixFederationHttpClient: async def get_file( self, - destination, - path, + destination: str, + path: str, output_stream, - args={}, - retry_on_dns_fail=True, - max_size=None, - ignore_backoff=False, - ): + args: Optional[QueryArgs] = None, + retry_on_dns_fail: bool = True, + max_size: Optional[int] = None, + ignore_backoff: bool = False, + ) -> Tuple[int, Dict[bytes, List[bytes]]]: """GETs a file from a given homeserver Args: - destination (str): The remote server to send the HTTP request to. - path (str): The HTTP path to GET. - output_stream (file): File to write the response body to. - args (dict): Optional dictionary used to create the query string. - ignore_backoff (bool): true to ignore the historical backoff data + destination: The remote server to send the HTTP request to. + path: The HTTP path to GET. + output_stream: File to write the response body to. + args: Optional dictionary used to create the query string. + ignore_backoff: true to ignore the historical backoff data and try the request anyway. Returns: - tuple[int, dict]: Resolves with an (int,dict) tuple of + Resolves with an (int,dict) tuple of the file length and a dict of the response headers. 
Raises: @@ -980,9 +978,15 @@ class MatrixFederationHttpClient: headers = dict(response.headers.getAllRawHeaders()) try: - d = _readBodyToFile(response, output_stream, max_size) + d = read_body_with_max_size(response, output_stream, max_size) d.addTimeout(self.default_timeout, self.reactor) length = await make_deferred_yieldable(d) + except BodyExceededMaxSize: + msg = "Requested file is too large > %r bytes" % (max_size,) + logger.warning( + "{%s} [%s] %s", request.txn_id, request.destination, msg, + ) + SynapseError(502, msg, Codes.TOO_LARGE) except Exception as e: logger.warning( "{%s} [%s] Error reading response: %s", @@ -1004,40 +1008,6 @@ class MatrixFederationHttpClient: return (length, headers) -class _ReadBodyToFileProtocol(protocol.Protocol): - def __init__(self, stream, deferred, max_size): - self.stream = stream - self.deferred = deferred - self.length = 0 - self.max_size = max_size - - def dataReceived(self, data): - self.stream.write(data) - self.length += len(data) - if self.max_size is not None and self.length >= self.max_size: - self.deferred.errback( - SynapseError( - 502, - "Requested file is too large > %r bytes" % (self.max_size,), - Codes.TOO_LARGE, - ) - ) - self.deferred = defer.Deferred() - self.transport.loseConnection() - - def connectionLost(self, reason): - if reason.check(ResponseDone): - self.deferred.callback(self.length) - else: - self.deferred.errback(reason) - - -def _readBodyToFile(response, stream, max_size): - d = defer.Deferred() - response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size)) - return d - - def _flatten_response_never_received(e): if hasattr(e, "reasons"): reasons = ", ".join( @@ -1049,13 +1019,13 @@ def _flatten_response_never_received(e): return repr(e) -def check_content_type_is_json(headers): +def check_content_type_is_json(headers: Headers) -> None: """ Check that a set of HTTP headers have a Content-Type header, and that it is application/json. Args: - headers (twisted.web.http_headers.Headers): headers to check + headers: headers to check Raises: RequestSendFailed: if the Content-Type header is missing or isn't JSON @@ -1063,27 +1033,18 @@ def check_content_type_is_json(headers): """ c_type = headers.getRawHeaders(b"Content-Type") if c_type is None: - raise RequestSendFailed(RuntimeError("No Content-Type header"), can_retry=False) + raise RequestSendFailed( + RuntimeError("No Content-Type header received from remote server"), + can_retry=False, + ) c_type = c_type[0].decode("ascii") # only the first header val, options = cgi.parse_header(c_type) if val != "application/json": raise RequestSendFailed( - RuntimeError("Content-Type not application/json: was '%s'" % c_type), + RuntimeError( + "Remote server sent Content-Type header of '%s', not 'application/json'" + % c_type, + ), can_retry=False, ) - - -def encode_query_args(args): - if args is None: - return b"" - - encoded_args = {} - for k, vs in args.items(): - if isinstance(vs, str): - vs = [vs] - encoded_args[k] = [v.encode("UTF-8") for v in vs] - - query_bytes = urllib.parse.urlencode(encoded_args, True) - - return query_bytes.encode("utf8") diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
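A quick illustration (not from the commit) of the cgi.parse_header call that the reworded check_content_type_is_json above relies on: it splits the media type from its parameters, so a charset suffix does not defeat the application/json check.

import cgi

value, params = cgi.parse_header("application/json; charset=utf-8")
assert value == "application/json"
assert params == {"charset": "utf-8"}

value, _ = cgi.parse_header("text/html; charset=ISO-8859-4")
assert value != "application/json"  # this is what triggers RequestSendFailed above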
index e32d3f43e0..b730d2c634 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py
@@ -39,6 +39,10 @@ class ProxyAgent(_AgentBase): reactor: twisted reactor to place outgoing connections. + proxy_reactor: twisted reactor to use for connections to the proxy server + reactor might have some blacklisting applied (i.e. for DNS queries), + but we need unblocked access to the proxy. + contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the verification parameters of OpenSSL. The default is to use a `BrowserLikePolicyForHTTPS`, so unless you have special @@ -59,6 +63,7 @@ class ProxyAgent(_AgentBase): def __init__( self, reactor, + proxy_reactor=None, contextFactory=BrowserLikePolicyForHTTPS(), connectTimeout=None, bindAddress=None, @@ -68,6 +73,11 @@ class ProxyAgent(_AgentBase): ): _AgentBase.__init__(self, reactor, pool) + if proxy_reactor is None: + self.proxy_reactor = reactor + else: + self.proxy_reactor = proxy_reactor + self._endpoint_kwargs = {} if connectTimeout is not None: self._endpoint_kwargs["timeout"] = connectTimeout @@ -75,11 +85,11 @@ class ProxyAgent(_AgentBase): self._endpoint_kwargs["bindAddress"] = bindAddress self.http_proxy_endpoint = _http_proxy_endpoint( - http_proxy, reactor, **self._endpoint_kwargs + http_proxy, self.proxy_reactor, **self._endpoint_kwargs ) self.https_proxy_endpoint = _http_proxy_endpoint( - https_proxy, reactor, **self._endpoint_kwargs + https_proxy, self.proxy_reactor, **self._endpoint_kwargs ) self._policy_for_https = contextFactory @@ -137,7 +147,7 @@ class ProxyAgent(_AgentBase): request_path = uri elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint: endpoint = HTTPConnectProxyEndpoint( - self._reactor, + self.proxy_reactor, self.https_proxy_endpoint, parsed_uri.host, parsed_uri.port, diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
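To show why proxy_reactor exists, here is a composition sketch (not part of the commit): the blacklisting reactor filters DNS answers for ordinary requests, but the proxy's own hostname still has to resolve, so the raw reactor is passed separately, just as the SimpleHttpClient hunk earlier passes hs.get_reactor(). The IPSet contents and proxy address are made up, and http_proxy is assumed to remain a host[:port] bytestring parameter.

from netaddr import IPSet

from twisted.internet import reactor

from synapse.http.client import BlacklistingReactorWrapper
from synapse.http.proxyagent import ProxyAgent

# DNS answers for ordinary requests go through the filtered reactor...
filtered_reactor = BlacklistingReactorWrapper(
    reactor, ip_whitelist=None, ip_blacklist=IPSet(["10.0.0.0/8", "127.0.0.0/8"])
)

# ...while connections to the proxy itself use the unfiltered reactor.
agent = ProxyAgent(
    filtered_reactor,
    reactor,
    http_proxy=b"proxy.example.com:3128",
)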
index cd94e789e8..7c5defec82 100644 --- a/synapse/http/request_metrics.py +++ b/synapse/http/request_metrics.py
@@ -109,7 +109,7 @@ in_flight_requests_db_sched_duration = Counter( # The set of all in flight requests, set[RequestMetrics] _in_flight_requests = set() -# Protects the _in_flight_requests set from concurrent accesss +# Protects the _in_flight_requests set from concurrent access _in_flight_requests_lock = threading.Lock() diff --git a/synapse/http/server.py b/synapse/http/server.py
index 996a31a9ec..e464bfe6c7 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py
@@ -25,7 +25,7 @@ from io import BytesIO from typing import Any, Callable, Dict, Iterator, List, Tuple, Union import jinja2 -from canonicaljson import iterencode_canonical_json, iterencode_pretty_printed_json +from canonicaljson import iterencode_canonical_json from zope.interface import implementer from twisted.internet import defer, interfaces @@ -35,8 +35,6 @@ from twisted.web.server import NOT_DONE_YET, Request from twisted.web.static import File, NoRangeStaticProducer from twisted.web.util import redirectTo -import synapse.events -import synapse.metrics from synapse.api.errors import ( CodeMessageException, Codes, @@ -96,11 +94,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None: pass else: respond_with_json( - request, - error_code, - error_dict, - send_cors=True, - pretty_print=_request_user_agent_is_curl(request), + request, error_code, error_dict, send_cors=True, ) @@ -182,7 +176,7 @@ class HttpServer: """ Register a callback that gets fired if we receive a http request with the given method for a path that matches the given regex. - If the regex contains groups these gets passed to the calback via + If the regex contains groups these gets passed to the callback via an unpacked tuple. Args: @@ -241,7 +235,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): async def _async_render(self, request: Request): """Delegates to `_async_render_<METHOD>` methods, or returns a 400 if - no appropriate method exists. Can be overriden in sub classes for + no appropriate method exists. Can be overridden in sub classes for different routing. """ # Treat HEAD requests as GET requests. @@ -257,7 +251,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)): callback_return = await raw_callback_return else: - callback_return = raw_callback_return + callback_return = raw_callback_return # type: ignore return callback_return @@ -281,6 +275,10 @@ class DirectServeJsonResource(_AsyncResource): formatting responses and errors as JSON. """ + def __init__(self, canonical_json=False, extract_context=False): + super().__init__(extract_context) + self.canonical_json = canonical_json + def _send_response( self, request: Request, code: int, response_object: Any, ): @@ -292,7 +290,6 @@ class DirectServeJsonResource(_AsyncResource): code, response_object, send_cors=True, - pretty_print=_request_user_agent_is_curl(request), canonical_json=self.canonical_json, ) @@ -325,9 +322,7 @@ class JsonResource(DirectServeJsonResource): ) def __init__(self, hs, canonical_json=True, extract_context=False): - super().__init__(extract_context) - - self.canonical_json = canonical_json + super().__init__(canonical_json, extract_context) self.clock = hs.get_clock() self.path_regexs = {} self.hs = hs @@ -386,7 +381,7 @@ class JsonResource(DirectServeJsonResource): async def _async_render(self, request): callback, servlet_classname, group_dict = self._get_handler_for_request(request) - # Make sure we have an appopriate name for this handler in prometheus + # Make sure we have an appropriate name for this handler in prometheus # (rather than the default of JsonResource). 
request.request_metrics.name = servlet_classname @@ -406,7 +401,7 @@ class JsonResource(DirectServeJsonResource): if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)): callback_return = await raw_callback_return else: - callback_return = raw_callback_return + callback_return = raw_callback_return # type: ignore return callback_return @@ -589,7 +584,6 @@ def respond_with_json( code: int, json_object: Any, send_cors: bool = False, - pretty_print: bool = False, canonical_json: bool = True, ): """Sends encoded JSON in response to the given request. @@ -600,8 +594,6 @@ def respond_with_json( json_object: The object to serialize to JSON. send_cors: Whether to send Cross-Origin Resource Sharing headers https://fetch.spec.whatwg.org/#http-cors-protocol - pretty_print: Whether to include indentation and line-breaks in the - resulting JSON bytes. canonical_json: Whether to use the canonicaljson algorithm when encoding the JSON bytes. @@ -617,13 +609,10 @@ def respond_with_json( ) return None - if pretty_print: - encoder = iterencode_pretty_printed_json + if canonical_json: + encoder = iterencode_canonical_json else: - if canonical_json or synapse.events.USE_FROZEN_DICTS: - encoder = iterencode_canonical_json - else: - encoder = _encode_json_bytes + encoder = _encode_json_bytes request.setResponseCode(code) request.setHeader(b"Content-Type", b"application/json") @@ -651,6 +640,11 @@ def respond_with_json_bytes( Returns: twisted.web.server.NOT_DONE_YET if the request is still active. """ + if request._disconnected: + logger.warning( + "Not sending response to request %s, already disconnected.", request + ) + return request.setResponseCode(code) request.setHeader(b"Content-Type", b"application/json") @@ -682,7 +676,7 @@ def set_cors_headers(request: Request): ) request.setHeader( b"Access-Control-Allow-Headers", - b"Origin, X-Requested-With, Content-Type, Accept, Authorization", + b"Origin, X-Requested-With, Content-Type, Accept, Authorization, Date", ) @@ -756,11 +750,3 @@ def finish_request(request: Request): request.finish() except RuntimeError as e: logger.info("Connection disconnected before response was written: %r", e) - - -def _request_user_agent_is_curl(request: Request) -> bool: - user_agents = request.requestHeaders.getRawHeaders(b"User-Agent", default=[]) - for user_agent in user_agents: - if b"curl" in user_agent: - return True - return False diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
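For reference (not part of the commit), this is what the canonical encoding kept by the server.py hunk above produces; iterencode_canonical_json is the streaming variant of the call shown here, and with the curl pretty-printing path removed it is used whenever canonical_json is left at its default of True.

from canonicaljson import encode_canonical_json

# Canonical JSON: sorted keys, no insignificant whitespace, UTF-8 bytes.
assert encode_canonical_json({"b": 1, "a": 2}) == b'{"a":2,"b":1}'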
index fd90ba7828..b361b7cbaf 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py
@@ -272,7 +272,6 @@ class RestServlet: on_PUT on_POST on_DELETE - on_OPTIONS Automatically handles turning CodeMessageExceptions thrown by these methods into the appropriate HTTP response. @@ -283,7 +282,7 @@ class RestServlet: if hasattr(self, "PATTERNS"): patterns = self.PATTERNS - for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"): + for method in ("GET", "PUT", "POST", "DELETE"): if hasattr(self, "on_%s" % (method,)): servlet_classname = self.__class__.__name__ method_handler = getattr(self, "on_%s" % (method,)) diff --git a/synapse/http/site.py b/synapse/http/site.py
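A minimal servlet sketch (not from the commit) showing the on_<METHOD> convention the loop above iterates over; with OPTIONS dropped from the tuple, only GET, PUT, POST and DELETE handlers are registered. The path and response are invented, and the (status, JSON) return shape follows the convention JsonResource expects.

import re

from synapse.http.servlet import RestServlet


class PingRestServlet(RestServlet):
    # PATTERNS is a list of compiled regexes matched against the request path.
    PATTERNS = [re.compile("^/_matrix/client/r0/ping$")]

    async def on_GET(self, request):
        return 200, {"pong": True}

    # An on_OPTIONS method defined here would no longer be picked up.

Registering an instance against a JsonResource wires each handler up under the servlet's class name, which is also what the prometheus naming in the server.py hunk above uses.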
index 6e79b47828..5a5790831b 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py
@@ -14,7 +14,7 @@ import contextlib import logging import time -from typing import Optional +from typing import Optional, Union from twisted.python.failure import Failure from twisted.web.server import Request, Site @@ -23,6 +23,7 @@ from synapse.config.server import ListenerConfig from synapse.http import redact_uri from synapse.http.request_metrics import RequestMetrics, requests_counter from synapse.logging.context import LoggingContext, PreserveLoggingContext +from synapse.types import Requester logger = logging.getLogger(__name__) @@ -54,9 +55,12 @@ class SynapseRequest(Request): Request.__init__(self, channel, *args, **kw) self.site = channel.site self._channel = channel # this is used by the tests - self.authenticated_entity = None self.start_time = 0.0 + # The requester, if authenticated. For federation requests this is the + # server name, for client requests this is the Requester object. + self.requester = None # type: Optional[Union[Requester, str]] + # we can't yet create the logcontext, as we don't know the method. self.logcontext = None # type: Optional[LoggingContext] @@ -109,8 +113,14 @@ class SynapseRequest(Request): method = self.method.decode("ascii") return method - def get_user_agent(self): - return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1] + def get_user_agent(self, default: str) -> str: + """Return the last User-Agent header, or the given default. + """ + user_agent = self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1] + if user_agent is None: + return default + + return user_agent.decode("ascii", "replace") def render(self, resrc): # this is called once a Resource has been found to serve the request; in our @@ -118,8 +128,7 @@ class SynapseRequest(Request): # create a LogContext for this request request_id = self.get_request_id() - logcontext = self.logcontext = LoggingContext(request_id) - logcontext.request = request_id + self.logcontext = LoggingContext(request_id, request=request_id) # override the Server header which is set by twisted self.setHeader("Server", self.site.server_version_string) @@ -161,7 +170,9 @@ class SynapseRequest(Request): yield except Exception: # this should already have been caught, and sent back to the client as a 500. - logger.exception("Asynchronous messge handler raised an uncaught exception") + logger.exception( + "Asynchronous message handler raised an uncaught exception" + ) finally: # the request handler has finished its work and either sent the whole response # back, or handed over responsibility to a Producer. @@ -263,22 +274,30 @@ class SynapseRequest(Request): # to the client (nb may be negative) response_send_time = self.finish_time - self._processing_finished_time - # need to decode as it could be raw utf-8 bytes - # from a IDN servname in an auth header - authenticated_entity = self.authenticated_entity - if authenticated_entity is not None and isinstance(authenticated_entity, bytes): - authenticated_entity = authenticated_entity.decode("utf-8", "replace") + # Convert the requester into a string that we can log + authenticated_entity = None + if isinstance(self.requester, str): + authenticated_entity = self.requester + elif isinstance(self.requester, Requester): + authenticated_entity = self.requester.authenticated_entity + + # If this is a request where the target user doesn't match the user who + # authenticated (e.g. and admin is puppetting a user) then we log both. 
+ if self.requester.user.to_string() != authenticated_entity: + authenticated_entity = "{},{}".format( + authenticated_entity, self.requester.user.to_string(), + ) + elif self.requester is not None: + # This shouldn't happen, but we log it so we don't lose information + # and can see that we're doing something wrong. + authenticated_entity = repr(self.requester) # type: ignore[unreachable] # ...or could be raw utf-8 bytes in the User-Agent header. # N.B. if you don't do this, the logger explodes cryptically # with maximum recursion trying to log errors about # the charset problem. # c.f. https://github.com/matrix-org/synapse/issues/3471 - user_agent = self.get_user_agent() - if user_agent is not None: - user_agent = user_agent.decode("utf-8", "replace") - else: - user_agent = "-" + user_agent = self.get_user_agent("-") code = str(self.code) if not self.finished: diff --git a/synapse/logging/__init__.py b/synapse/logging/__init__.py
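Restating the access-log rule from the site.py hunks above as a standalone function (not part of the commit; the helper name and example user IDs are invented): the authenticated entity is logged as-is, except when an admin is puppeting another user, in which case both are logged, comma-separated.

def entity_for_access_log(authenticated_entity: str, target_user: str) -> str:
    # Puppeting: the user who authenticated differs from the user being acted for.
    if target_user != authenticated_entity:
        return "{},{}".format(authenticated_entity, target_user)
    return authenticated_entity


assert entity_for_access_log("@alice:example.com", "@alice:example.com") == "@alice:example.com"
assert (
    entity_for_access_log("@admin:example.com", "@alice:example.com")
    == "@admin:example.com,@alice:example.com"
)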
index e69de29bb2..b28b7b2ef7 100644 --- a/synapse/logging/__init__.py +++ b/synapse/logging/__init__.py
@@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# These are imported to allow for nicer logging configuration files. +from synapse.logging._remote import RemoteHandler +from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter + +__all__ = ["RemoteHandler", "JsonFormatter", "TerseJsonFormatter"] diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py new file mode 100644
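The re-exports above exist so that logging configuration files can name these classes directly. A configuration sketch follows (not from the commit; host, port and buffer size are example values, and the wiring is the standard logging.config.dictConfig mechanism, where extra handler keys are passed to the handler constructor).

import logging.config

logging.config.dictConfig(
    {
        "version": 1,
        "formatters": {
            "json": {"class": "synapse.logging.TerseJsonFormatter"},
        },
        "handlers": {
            "remote": {
                "class": "synapse.logging.RemoteHandler",
                "formatter": "json",
                "host": "127.0.0.1",
                "port": 9000,
                "maximum_buffer": 1000,  # records kept while the TCP target is unreachable
            },
        },
        "root": {"level": "INFO", "handlers": ["remote"]},
    }
)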
index 0000000000..fb937b3f28 --- /dev/null +++ b/synapse/logging/_remote.py
@@ -0,0 +1,241 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import sys +import traceback +from collections import deque +from ipaddress import IPv4Address, IPv6Address, ip_address +from math import floor +from typing import Callable, Optional + +import attr +from typing_extensions import Deque +from zope.interface import implementer + +from twisted.application.internet import ClientService +from twisted.internet.defer import CancelledError, Deferred +from twisted.internet.endpoints import ( + HostnameEndpoint, + TCP4ClientEndpoint, + TCP6ClientEndpoint, +) +from twisted.internet.interfaces import IPushProducer, ITransport +from twisted.internet.protocol import Factory, Protocol +from twisted.python.failure import Failure + +logger = logging.getLogger(__name__) + + +@attr.s +@implementer(IPushProducer) +class LogProducer: + """ + An IPushProducer that writes logs from its buffer to its transport when it + is resumed. + + Args: + buffer: Log buffer to read logs from. + transport: Transport to write to. + format: A callable to format the log record to a string. + """ + + transport = attr.ib(type=ITransport) + _format = attr.ib(type=Callable[[logging.LogRecord], str]) + _buffer = attr.ib(type=deque) + _paused = attr.ib(default=False, type=bool, init=False) + + def pauseProducing(self): + self._paused = True + + def stopProducing(self): + self._paused = True + self._buffer = deque() + + def resumeProducing(self): + # If we're already producing, nothing to do. + self._paused = False + + # Loop until paused. + while self._paused is False and (self._buffer and self.transport.connected): + try: + # Request the next record and format it. + record = self._buffer.popleft() + msg = self._format(record) + + # Send it as a new line over the transport. + self.transport.write(msg.encode("utf8")) + self.transport.write(b"\n") + except Exception: + # Something has gone wrong writing to the transport -- log it + # and break out of the while. + traceback.print_exc(file=sys.__stderr__) + break + + +class RemoteHandler(logging.Handler): + """ + An logging handler that writes logs to a TCP target. + + Args: + host: The host of the logging target. + port: The logging target's port. + maximum_buffer: The maximum buffer size. + """ + + def __init__( + self, + host: str, + port: int, + maximum_buffer: int = 1000, + level=logging.NOTSET, + _reactor=None, + ): + super().__init__(level=level) + self.host = host + self.port = port + self.maximum_buffer = maximum_buffer + + self._buffer = deque() # type: Deque[logging.LogRecord] + self._connection_waiter = None # type: Optional[Deferred] + self._producer = None # type: Optional[LogProducer] + + # Connect without DNS lookups if it's a direct IP. 
+ if _reactor is None: + from twisted.internet import reactor + + _reactor = reactor + + try: + ip = ip_address(self.host) + if isinstance(ip, IPv4Address): + endpoint = TCP4ClientEndpoint(_reactor, self.host, self.port) + elif isinstance(ip, IPv6Address): + endpoint = TCP6ClientEndpoint(_reactor, self.host, self.port) + else: + raise ValueError("Unknown IP address provided: %s" % (self.host,)) + except ValueError: + endpoint = HostnameEndpoint(_reactor, self.host, self.port) + + factory = Factory.forProtocol(Protocol) + self._service = ClientService(endpoint, factory, clock=_reactor) + self._service.startService() + self._stopping = False + self._connect() + + def close(self): + self._stopping = True + self._service.stopService() + + def _connect(self) -> None: + """ + Triggers an attempt to connect then write to the remote if not already writing. + """ + # Do not attempt to open multiple connections. + if self._connection_waiter: + return + + self._connection_waiter = self._service.whenConnected(failAfterFailures=1) + + def fail(failure: Failure) -> None: + # If the Deferred was cancelled (e.g. during shutdown) do not try to + # reconnect (this will cause an infinite loop of errors). + if failure.check(CancelledError) and self._stopping: + return + + # For a different error, print the traceback and re-connect. + failure.printTraceback(file=sys.__stderr__) + self._connection_waiter = None + self._connect() + + def writer(result: Protocol) -> None: + # We have a connection. If we already have a producer, and its + # transport is the same, just trigger a resumeProducing. + if self._producer and result.transport is self._producer.transport: + self._producer.resumeProducing() + self._connection_waiter = None + return + + # If the producer is still producing, stop it. + if self._producer: + self._producer.stopProducing() + + # Make a new producer and start it. + self._producer = LogProducer( + buffer=self._buffer, transport=result.transport, format=self.format, + ) + result.transport.registerProducer(self._producer, True) + self._producer.resumeProducing() + self._connection_waiter = None + + self._connection_waiter.addCallbacks(writer, fail) + + def _handle_pressure(self) -> None: + """ + Handle backpressure by shedding records. + + The buffer will, in this order, until the buffer is below the maximum: + - Shed DEBUG records. + - Shed INFO records. + - Shed the middle 50% of the records. + """ + if len(self._buffer) <= self.maximum_buffer: + return + + # Strip out DEBUGs + self._buffer = deque( + filter(lambda record: record.levelno > logging.DEBUG, self._buffer) + ) + + if len(self._buffer) <= self.maximum_buffer: + return + + # Strip out INFOs + self._buffer = deque( + filter(lambda record: record.levelno > logging.INFO, self._buffer) + ) + + if len(self._buffer) <= self.maximum_buffer: + return + + # Cut the middle entries out + buffer_split = floor(self.maximum_buffer / 2) + + old_buffer = self._buffer + self._buffer = deque() + + for i in range(buffer_split): + self._buffer.append(old_buffer.popleft()) + + end_buffer = [] + for i in range(buffer_split): + end_buffer.append(old_buffer.pop()) + + self._buffer.extend(reversed(end_buffer)) + + def emit(self, record: logging.LogRecord) -> None: + self._buffer.append(record) + + # Handle backpressure, if it exists. + try: + self._handle_pressure() + except Exception: + # If handling backpressure fails, clear the buffer and log the + # exception. 
+            self._buffer.clear()
+            logger.warning("Failed clearing backpressure")
+
+        # Try and write immediately.
+        self._connect()
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
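The _handle_pressure method added above sheds buffered records in a fixed order once the buffer overflows. A standalone sketch of the same strategy, written against a plain deque purely for illustration (not the handler's actual code path):

import logging
from collections import deque
from math import floor
from typing import Deque


def shed(buffer: Deque[logging.LogRecord], maximum: int) -> Deque[logging.LogRecord]:
    # Nothing to do while the buffer fits.
    if len(buffer) <= maximum:
        return buffer

    # 1. Drop DEBUG records.
    buffer = deque(r for r in buffer if r.levelno > logging.DEBUG)
    if len(buffer) <= maximum:
        return buffer

    # 2. Drop INFO records.
    buffer = deque(r for r in buffer if r.levelno > logging.INFO)
    if len(buffer) <= maximum:
        return buffer

    # 3. Keep the first and last maximum/2 records and drop the middle.
    half = floor(maximum / 2)
    records = list(buffer)
    return deque(records[:half] + records[len(records) - half:])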
index 144506c8f2..14d9c104c2 100644
--- a/synapse/logging/_structured.py
+++ b/synapse/logging/_structured.py
@@ -12,146 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import logging import os.path -import sys -import typing -import warnings -from typing import List - -import attr -from constantly import NamedConstant, Names, ValueConstant, Values -from zope.interface import implementer +from typing import Any, Dict, Generator, Optional, Tuple -from twisted.logger import ( - FileLogObserver, - FilteringLogObserver, - ILogObserver, - LogBeginner, - Logger, - LogLevel, - LogLevelFilterPredicate, - LogPublisher, - eventAsText, - jsonFileLogObserver, -) +from constantly import NamedConstant, Names from synapse.config._base import ConfigError -from synapse.logging._terse_json import ( - TerseJSONToConsoleLogObserver, - TerseJSONToTCPLogObserver, -) -from synapse.logging.context import current_context - - -def stdlib_log_level_to_twisted(level: str) -> LogLevel: - """ - Convert a stdlib log level to Twisted's log level. - """ - lvl = level.lower().replace("warning", "warn") - return LogLevel.levelWithName(lvl) - - -@attr.s -@implementer(ILogObserver) -class LogContextObserver: - """ - An ILogObserver which adds Synapse-specific log context information. - - Attributes: - observer (ILogObserver): The target parent observer. - """ - - observer = attr.ib() - - def __call__(self, event: dict) -> None: - """ - Consume a log event and emit it to the parent observer after filtering - and adding log context information. - - Args: - event (dict) - """ - # Filter out some useless events that Twisted outputs - if "log_text" in event: - if event["log_text"].startswith("DNSDatagramProtocol starting on "): - return - - if event["log_text"].startswith("(UDP Port "): - return - - if event["log_text"].startswith("Timing out client") or event[ - "log_format" - ].startswith("Timing out client"): - return - - context = current_context() - - # Copy the context information to the log event. - if context is not None: - context.copy_to_twisted_log_entry(event) - else: - # If there's no logging context, not even the root one, we might be - # starting up or it might be from non-Synapse code. Log it as if it - # came from the root logger. - event["request"] = None - event["scope"] = None - - self.observer(event) - - -class PythonStdlibToTwistedLogger(logging.Handler): - """ - Transform a Python stdlib log message into a Twisted one. - """ - - def __init__(self, observer, *args, **kwargs): - """ - Args: - observer (ILogObserver): A Twisted logging observer. - *args, **kwargs: Args/kwargs to be passed to logging.Handler. - """ - self.observer = observer - super().__init__(*args, **kwargs) - - def emit(self, record: logging.LogRecord) -> None: - """ - Emit a record to Twisted's observer. - - Args: - record (logging.LogRecord) - """ - - self.observer( - { - "log_time": record.created, - "log_text": record.getMessage(), - "log_format": "{log_text}", - "log_namespace": record.name, - "log_level": stdlib_log_level_to_twisted(record.levelname), - } - ) - - -def SynapseFileLogObserver(outFile: typing.IO[str]) -> FileLogObserver: - """ - A log observer that formats events like the traditional log formatter and - sends them to `outFile`. - - Args: - outFile (file object): The file object to write to. 
- """ - - def formatEvent(_event: dict) -> str: - event = dict(_event) - event["log_level"] = event["log_level"].name.upper() - event["log_format"] = "- {log_namespace} - {log_level} - {request} - " + ( - event.get("log_format", "{log_text}") or "{log_text}" - ) - return eventAsText(event, includeSystem=False) + "\n" - - return FileLogObserver(outFile, formatEvent) class DrainType(Names): @@ -163,30 +29,12 @@ class DrainType(Names): NETWORK_JSON_TERSE = NamedConstant() -class OutputPipeType(Values): - stdout = ValueConstant(sys.__stdout__) - stderr = ValueConstant(sys.__stderr__) - - -@attr.s -class DrainConfiguration: - name = attr.ib() - type = attr.ib() - location = attr.ib() - options = attr.ib(default=None) - - -@attr.s -class NetworkJSONTerseOptions: - maximum_buffer = attr.ib(type=int) - - -DEFAULT_LOGGERS = {"synapse": {"level": "INFO"}} +DEFAULT_LOGGERS = {"synapse": {"level": "info"}} def parse_drain_configs( drains: dict, -) -> typing.Generator[DrainConfiguration, None, None]: +) -> Generator[Tuple[str, Dict[str, Any]], None, None]: """ Parse the drain configurations. @@ -194,11 +42,12 @@ def parse_drain_configs( drains (dict): A list of drain configurations. Yields: - DrainConfiguration instances. + dict instances representing a logging handler. Raises: ConfigError: If any of the drain configuration items are invalid. """ + for name, config in drains.items(): if "type" not in config: raise ConfigError("Logging drains require a 'type' key.") @@ -210,6 +59,18 @@ def parse_drain_configs( "%s is not a known logging drain type." % (config["type"],) ) + # Either use the default formatter or the tersejson one. + if logging_type in (DrainType.CONSOLE_JSON, DrainType.FILE_JSON,): + formatter = "json" # type: Optional[str] + elif logging_type in ( + DrainType.CONSOLE_JSON_TERSE, + DrainType.NETWORK_JSON_TERSE, + ): + formatter = "tersejson" + else: + # A formatter of None implies using the default formatter. + formatter = None + if logging_type in [ DrainType.CONSOLE, DrainType.CONSOLE_JSON, @@ -225,9 +86,11 @@ def parse_drain_configs( % (logging_type,) ) - pipe = OutputPipeType.lookupByName(location).value - - yield DrainConfiguration(name=name, type=logging_type, location=pipe) + yield name, { + "class": "logging.StreamHandler", + "formatter": formatter, + "stream": "ext://sys." + location, + } elif logging_type in [DrainType.FILE, DrainType.FILE_JSON]: if "location" not in config: @@ -241,18 +104,25 @@ def parse_drain_configs( "File paths need to be absolute, '%s' is a relative path" % (location,) ) - yield DrainConfiguration(name=name, type=logging_type, location=location) + + yield name, { + "class": "logging.FileHandler", + "formatter": formatter, + "filename": location, + } elif logging_type in [DrainType.NETWORK_JSON_TERSE]: host = config.get("host") port = config.get("port") maximum_buffer = config.get("maximum_buffer", 1000) - yield DrainConfiguration( - name=name, - type=logging_type, - location=(host, port), - options=NetworkJSONTerseOptions(maximum_buffer=maximum_buffer), - ) + + yield name, { + "class": "synapse.logging.RemoteHandler", + "formatter": formatter, + "host": host, + "port": port, + "maximum_buffer": maximum_buffer, + } else: raise ConfigError( @@ -261,126 +131,29 @@ def parse_drain_configs( ) -class StoppableLogPublisher(LogPublisher): - """ - A log publisher that can tell its observers to shut down any external - communications. 
+def setup_structured_logging(log_config: dict,) -> dict: """ - - def stop(self): - for obs in self._observers: - if hasattr(obs, "stop"): - obs.stop() - - -def setup_structured_logging( - hs, - config, - log_config: dict, - logBeginner: LogBeginner, - redirect_stdlib_logging: bool = True, -) -> LogPublisher: + Convert a legacy structured logging configuration (from Synapse < v1.23.0) + to one compatible with the new standard library handlers. """ - Set up Twisted's structured logging system. - - Args: - hs: The homeserver to use. - config (HomeserverConfig): The configuration of the Synapse homeserver. - log_config (dict): The log configuration to use. - """ - if config.no_redirect_stdio: - raise ConfigError( - "no_redirect_stdio cannot be defined using structured logging." - ) - - logger = Logger() - if "drains" not in log_config: raise ConfigError("The logging configuration requires a list of drains.") - observers = [] # type: List[ILogObserver] - - for observer in parse_drain_configs(log_config["drains"]): - # Pipe drains - if observer.type == DrainType.CONSOLE: - logger.debug( - "Starting up the {name} console logger drain", name=observer.name - ) - observers.append(SynapseFileLogObserver(observer.location)) - elif observer.type == DrainType.CONSOLE_JSON: - logger.debug( - "Starting up the {name} JSON console logger drain", name=observer.name - ) - observers.append(jsonFileLogObserver(observer.location)) - elif observer.type == DrainType.CONSOLE_JSON_TERSE: - logger.debug( - "Starting up the {name} terse JSON console logger drain", - name=observer.name, - ) - observers.append( - TerseJSONToConsoleLogObserver(observer.location, metadata={}) - ) - - # File drains - elif observer.type == DrainType.FILE: - logger.debug("Starting up the {name} file logger drain", name=observer.name) - log_file = open(observer.location, "at", buffering=1, encoding="utf8") - observers.append(SynapseFileLogObserver(log_file)) - elif observer.type == DrainType.FILE_JSON: - logger.debug( - "Starting up the {name} JSON file logger drain", name=observer.name - ) - log_file = open(observer.location, "at", buffering=1, encoding="utf8") - observers.append(jsonFileLogObserver(log_file)) - - elif observer.type == DrainType.NETWORK_JSON_TERSE: - metadata = {"server_name": hs.config.server_name} - log_observer = TerseJSONToTCPLogObserver( - hs=hs, - host=observer.location[0], - port=observer.location[1], - metadata=metadata, - maximum_buffer=observer.options.maximum_buffer, - ) - log_observer.start() - observers.append(log_observer) - else: - # We should never get here, but, just in case, throw an error. 
- raise ConfigError("%s drain type cannot be configured" % (observer.type,)) - - publisher = StoppableLogPublisher(*observers) - log_filter = LogLevelFilterPredicate() - - for namespace, namespace_config in log_config.get( - "loggers", DEFAULT_LOGGERS - ).items(): - # Set the log level for twisted.logger.Logger namespaces - log_filter.setLogLevelForNamespace( - namespace, - stdlib_log_level_to_twisted(namespace_config.get("level", "INFO")), - ) - - # Also set the log levels for the stdlib logger namespaces, to prevent - # them getting to PythonStdlibToTwistedLogger and having to be formatted - if "level" in namespace_config: - logging.getLogger(namespace).setLevel(namespace_config.get("level")) - - f = FilteringLogObserver(publisher, [log_filter]) - lco = LogContextObserver(f) - - if redirect_stdlib_logging: - stuff_into_twisted = PythonStdlibToTwistedLogger(lco) - stdliblogger = logging.getLogger() - stdliblogger.addHandler(stuff_into_twisted) - - # Always redirect standard I/O, otherwise other logging outputs might miss - # it. - logBeginner.beginLoggingTo([lco], redirectStandardIO=True) + new_config = { + "version": 1, + "formatters": { + "json": {"class": "synapse.logging.JsonFormatter"}, + "tersejson": {"class": "synapse.logging.TerseJsonFormatter"}, + }, + "handlers": {}, + "loggers": log_config.get("loggers", DEFAULT_LOGGERS), + "root": {"handlers": []}, + } - return publisher + for handler_name, handler in parse_drain_configs(log_config["drains"]): + new_config["handlers"][handler_name] = handler + # Add each handler to the root logger. + new_config["root"]["handlers"].append(handler_name) -def reload_structured_logging(*args, log_config=None) -> None: - warnings.warn( - "Currently the structured logging system can not be reloaded, doing nothing" - ) + return new_config diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py
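After this change, setup_structured_logging simply rewrites a pre-1.23.0 "drains" configuration into a standard library dict config. A sketch of the mapping for a single network drain; the host, port and buffer values are illustrative assumptions:

legacy_config = {
    "drains": {
        "remote": {
            "type": "network_json_terse",
            "host": "10.1.2.3",    # assumed collector address
            "port": 9000,
            "maximum_buffer": 1000,
        },
    },
}

# setup_structured_logging(legacy_config) should return roughly:
#
# {
#     "version": 1,
#     "formatters": {
#         "json": {"class": "synapse.logging.JsonFormatter"},
#         "tersejson": {"class": "synapse.logging.TerseJsonFormatter"},
#     },
#     "handlers": {
#         "remote": {
#             "class": "synapse.logging.RemoteHandler",
#             "formatter": "tersejson",
#             "host": "10.1.2.3",
#             "port": 9000,
#             "maximum_buffer": 1000,
#         },
#     },
#     "loggers": {"synapse": {"level": "info"}},
#     "root": {"handlers": ["remote"]},
# }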
index 1b8916cfa2..2fbf5549a1 100644
--- a/synapse/logging/_terse_json.py
+++ b/synapse/logging/_terse_json.py
@@ -16,314 +16,65 @@ """ Log formatters that output terse JSON. """ - import json -import sys -import traceback -from collections import deque -from ipaddress import IPv4Address, IPv6Address, ip_address -from math import floor -from typing import IO, Optional - -import attr -from zope.interface import implementer - -from twisted.application.internet import ClientService -from twisted.internet.defer import Deferred -from twisted.internet.endpoints import ( - HostnameEndpoint, - TCP4ClientEndpoint, - TCP6ClientEndpoint, -) -from twisted.internet.interfaces import IPushProducer, ITransport -from twisted.internet.protocol import Factory, Protocol -from twisted.logger import FileLogObserver, ILogObserver, Logger +import logging _encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":")) - -def flatten_event(event: dict, metadata: dict, include_time: bool = False): - """ - Flatten a Twisted logging event to an dictionary capable of being sent - as a log event to a logging aggregation system. - - The format is vastly simplified and is not designed to be a "human readable - string" in the sense that traditional logs are. Instead, the structure is - optimised for searchability and filtering, with human-understandable log - keys. - - Args: - event (dict): The Twisted logging event we are flattening. - metadata (dict): Additional data to include with each log message. This - can be information like the server name. Since the target log - consumer does not know who we are other than by host IP, this - allows us to forward through static information. - include_time (bool): Should we include the `time` key? If False, the - event time is stripped from the event. - """ - new_event = {} - - # If it's a failure, make the new event's log_failure be the traceback text. - if "log_failure" in event: - new_event["log_failure"] = event["log_failure"].getTraceback() - - # If it's a warning, copy over a string representation of the warning. - if "warning" in event: - new_event["warning"] = str(event["warning"]) - - # Stdlib logging events have "log_text" as their human-readable portion, - # Twisted ones have "log_format". For now, include the log_format, so that - # context only given in the log format (e.g. what is being logged) is - # available. - if "log_text" in event: - new_event["log"] = event["log_text"] - else: - new_event["log"] = event["log_format"] - - # We want to include the timestamp when forwarding over the network, but - # exclude it when we are writing to stdout. This is because the log ingester - # (e.g. logstash, fluentd) can add its own timestamp. - if include_time: - new_event["time"] = round(event["log_time"], 2) - - # Convert the log level to a textual representation. - new_event["level"] = event["log_level"].name.upper() - - # Ignore these keys, and do not transfer them over to the new log object. - # They are either useless (isError), transferred manually above (log_time, - # log_level, etc), or contain Python objects which are not useful for output - # (log_logger, log_source). - keys_to_delete = [ - "isError", - "log_failure", - "log_format", - "log_level", - "log_logger", - "log_source", - "log_system", - "log_time", - "log_text", - "observer", - "warning", - ] - - # If it's from the Twisted legacy logger (twisted.python.log), it adds some - # more keys we want to purge. - if event.get("log_namespace") == "log_legacy": - keys_to_delete.extend(["message", "system", "time"]) - - # Rather than modify the dictionary in place, construct a new one with only - # the content we want. 
The original event should be considered 'frozen'. - for key in event.keys(): - - if key in keys_to_delete: - continue - - if isinstance(event[key], (str, int, bool, float)) or event[key] is None: - # If it's a plain type, include it as is. - new_event[key] = event[key] - else: - # If it's not one of those basic types, write out a string - # representation. This should probably be a warning in development, - # so that we are sure we are only outputting useful data. - new_event[key] = str(event[key]) - - # Add the metadata information to the event (e.g. the server_name). - new_event.update(metadata) - - return new_event - - -def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogObserver: - """ - A log observer that formats events to a flattened JSON representation. - - Args: - outFile: The file object to write to. - metadata: Metadata to be added to each log object. - """ - - def formatEvent(_event: dict) -> str: - flattened = flatten_event(_event, metadata) - return _encoder.encode(flattened) + "\n" - - return FileLogObserver(outFile, formatEvent) - - -@attr.s -@implementer(IPushProducer) -class LogProducer: - """ - An IPushProducer that writes logs from its buffer to its transport when it - is resumed. - - Args: - buffer: Log buffer to read logs from. - transport: Transport to write to. - """ - - transport = attr.ib(type=ITransport) - _buffer = attr.ib(type=deque) - _paused = attr.ib(default=False, type=bool, init=False) - - def pauseProducing(self): - self._paused = True - - def stopProducing(self): - self._paused = True - self._buffer = deque() - - def resumeProducing(self): - self._paused = False - - while self._paused is False and (self._buffer and self.transport.connected): - try: - event = self._buffer.popleft() - self.transport.write(_encoder.encode(event).encode("utf8")) - self.transport.write(b"\n") - except Exception: - # Something has gone wrong writing to the transport -- log it - # and break out of the while. - traceback.print_exc(file=sys.__stderr__) - break - - -@attr.s -@implementer(ILogObserver) -class TerseJSONToTCPLogObserver: - """ - An IObserver that writes JSON logs to a TCP target. - - Args: - hs (HomeServer): The homeserver that is being logged for. - host: The host of the logging target. - port: The logging target's port. - metadata: Metadata to be added to each log entry. - """ - - hs = attr.ib() - host = attr.ib(type=str) - port = attr.ib(type=int) - metadata = attr.ib(type=dict) - maximum_buffer = attr.ib(type=int) - _buffer = attr.ib(default=attr.Factory(deque), type=deque) - _connection_waiter = attr.ib(default=None, type=Optional[Deferred]) - _logger = attr.ib(default=attr.Factory(Logger)) - _producer = attr.ib(default=None, type=Optional[LogProducer]) - - def start(self) -> None: - - # Connect without DNS lookups if it's a direct IP. - try: - ip = ip_address(self.host) - if isinstance(ip, IPv4Address): - endpoint = TCP4ClientEndpoint( - self.hs.get_reactor(), self.host, self.port - ) - elif isinstance(ip, IPv6Address): - endpoint = TCP6ClientEndpoint( - self.hs.get_reactor(), self.host, self.port - ) - except ValueError: - endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port) - - factory = Factory.forProtocol(Protocol) - self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor()) - self._service.startService() - self._connect() - - def stop(self): - self._service.stopService() - - def _connect(self) -> None: - """ - Triggers an attempt to connect then write to the remote if not already writing. 
- """ - if self._connection_waiter: - return - - self._connection_waiter = self._service.whenConnected(failAfterFailures=1) - - @self._connection_waiter.addErrback - def fail(r): - r.printTraceback(file=sys.__stderr__) - self._connection_waiter = None - self._connect() - - @self._connection_waiter.addCallback - def writer(r): - # We have a connection. If we already have a producer, and its - # transport is the same, just trigger a resumeProducing. - if self._producer and r.transport is self._producer.transport: - self._producer.resumeProducing() - self._connection_waiter = None - return - - # If the producer is still producing, stop it. - if self._producer: - self._producer.stopProducing() - - # Make a new producer and start it. - self._producer = LogProducer(buffer=self._buffer, transport=r.transport) - r.transport.registerProducer(self._producer, True) - self._producer.resumeProducing() - self._connection_waiter = None - - def _handle_pressure(self) -> None: - """ - Handle backpressure by shedding events. - - The buffer will, in this order, until the buffer is below the maximum: - - Shed DEBUG events - - Shed INFO events - - Shed the middle 50% of the events. - """ - if len(self._buffer) <= self.maximum_buffer: - return - - # Strip out DEBUGs - self._buffer = deque( - filter(lambda event: event["level"] != "DEBUG", self._buffer) - ) - - if len(self._buffer) <= self.maximum_buffer: - return - - # Strip out INFOs - self._buffer = deque( - filter(lambda event: event["level"] != "INFO", self._buffer) - ) - - if len(self._buffer) <= self.maximum_buffer: - return - - # Cut the middle entries out - buffer_split = floor(self.maximum_buffer / 2) - - old_buffer = self._buffer - self._buffer = deque() - - for i in range(buffer_split): - self._buffer.append(old_buffer.popleft()) - - end_buffer = [] - for i in range(buffer_split): - end_buffer.append(old_buffer.pop()) - - self._buffer.extend(reversed(end_buffer)) - - def __call__(self, event: dict) -> None: - flattened = flatten_event(event, self.metadata, include_time=True) - self._buffer.append(flattened) - - # Handle backpressure, if it exists. - try: - self._handle_pressure() - except Exception: - # If handling backpressure fails,clear the buffer and log the - # exception. - self._buffer.clear() - self._logger.failure("Failed clearing backpressure") - - # Try and write immediately. - self._connect() +# The properties of a standard LogRecord. +_LOG_RECORD_ATTRIBUTES = { + "args", + "asctime", + "created", + "exc_info", + # exc_text isn't a public attribute, but is used to cache the result of formatException. + "exc_text", + "filename", + "funcName", + "levelname", + "levelno", + "lineno", + "message", + "module", + "msecs", + "msg", + "name", + "pathname", + "process", + "processName", + "relativeCreated", + "stack_info", + "thread", + "threadName", +} + + +class JsonFormatter(logging.Formatter): + def format(self, record: logging.LogRecord) -> str: + event = { + "log": record.getMessage(), + "namespace": record.name, + "level": record.levelname, + } + + return self._format(record, event) + + def _format(self, record: logging.LogRecord, event: dict) -> str: + # Add any extra attributes to the event. 
+ for key, value in record.__dict__.items(): + if key not in _LOG_RECORD_ATTRIBUTES: + event[key] = value + + return _encoder.encode(event) + + +class TerseJsonFormatter(JsonFormatter): + def format(self, record: logging.LogRecord) -> str: + event = { + "log": record.getMessage(), + "namespace": record.name, + "level": record.levelname, + "time": round(record.created, 2), + } + + return self._format(record, event) diff --git a/synapse/logging/context.py b/synapse/logging/context.py
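The formatters above serialise the standard record fields plus any attributes on the record that are not in _LOG_RECORD_ATTRIBUTES. A usage sketch, assuming an illustrative logger name and extra field:

import logging
import sys

from synapse.logging import TerseJsonFormatter

handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(TerseJsonFormatter())

logger = logging.getLogger("demo")  # assumed logger name
logger.setLevel(logging.INFO)
logger.addHandler(handler)

# "request" is not a standard LogRecord attribute, so it is emitted as an
# extra JSON key next to "log", "namespace", "level" and "time".
logger.info("processed request", extra={"request": "GET-123"})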
index ca0c774cc5..a507a83e93 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -203,10 +203,6 @@ class _Sentinel: def copy_to(self, record): pass - def copy_to_twisted_log_entry(self, record): - record["request"] = None - record["scope"] = None - def start(self, rusage: "Optional[resource._RUsage]"): pass @@ -372,13 +368,6 @@ class LoggingContext: # we also track the current scope: record.scope = self.scope - def copy_to_twisted_log_entry(self, record) -> None: - """ - Copy logging fields from this context to a Twisted log record. - """ - record["request"] = self.request - record["scope"] = self.scope - def start(self, rusage: "Optional[resource._RUsage]") -> None: """ Record that this logcontext is currently running. @@ -542,13 +531,10 @@ class LoggingContext: class LoggingContextFilter(logging.Filter): """Logging filter that adds values from the current logging context to each record. - Args: - **defaults: Default values to avoid formatters complaining about - missing fields """ - def __init__(self, **defaults) -> None: - self.defaults = defaults + def __init__(self, request: str = ""): + self._default_request = request def filter(self, record) -> Literal[True]: """Add each fields from the logging contexts to the record. @@ -556,14 +542,14 @@ class LoggingContextFilter(logging.Filter): True to include the record in the log output. """ context = current_context() - for key, value in self.defaults.items(): - setattr(record, key, value) + record.request = self._default_request # context should never be None, but if it somehow ends up being, then # we end up in a death spiral of infinite loops, so let's check, for # robustness' sake. if context is not None: - context.copy_to(record) + # Logging is interested in the request. + record.request = context.request return True diff --git a/synapse/logging/filter.py b/synapse/logging/filter.py new file mode 100644
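LoggingContextFilter above now takes only a default request string and copies the current context's request onto each record. A wiring sketch, assuming a plain stream handler:

import logging

from synapse.logging.context import LoggingContextFilter

handler = logging.StreamHandler()
# Outside any logging context, records get request="" from the default.
handler.addFilter(LoggingContextFilter(request=""))
logging.getLogger().addHandler(handler)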
index 0000000000..1baf8dd679
--- /dev/null
+++ b/synapse/logging/filter.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from typing_extensions import Literal
+
+
+class MetadataFilter(logging.Filter):
+    """Logging filter that adds constant values to each record.
+
+    Args:
+        metadata: Key-value pairs to add to each record.
+    """
+
+    def __init__(self, metadata: dict):
+        self._metadata = metadata
+
+    def filter(self, record: logging.LogRecord) -> Literal[True]:
+        for key, value in self._metadata.items():
+            setattr(record, key, value)
+        return True
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
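The new MetadataFilter above attaches constant key/value pairs to every record, for example so a JSON formatter can forward the server name to a log aggregator. A sketch with an assumed server name:

import logging

from synapse.logging.filter import MetadataFilter

handler = logging.StreamHandler()
handler.addFilter(MetadataFilter({"server_name": "example.com"}))  # assumed value
logging.getLogger().addHandler(handler)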
index e58850faff..ab586c318c 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -317,7 +317,7 @@ def ensure_active_span(message, ret=None): @contextlib.contextmanager -def _noop_context_manager(*args, **kwargs): +def noop_context_manager(*args, **kwargs): """Does exactly what it says on the tin""" yield @@ -413,7 +413,7 @@ def start_active_span( """ if opentracing is None: - return _noop_context_manager() + return noop_context_manager() return opentracing.tracer.start_active_span( operation_name, @@ -428,7 +428,7 @@ def start_active_span( def start_active_span_follows_from(operation_name, contexts): if opentracing is None: - return _noop_context_manager() + return noop_context_manager() references = [opentracing.follows_from(context) for context in contexts] scope = start_active_span(operation_name, references=references) @@ -459,7 +459,7 @@ def start_active_span_from_request( # Also, twisted uses byte arrays while opentracing expects strings. if opentracing is None: - return _noop_context_manager() + return noop_context_manager() header_dict = { k.decode(): v[0].decode() for k, v in request.requestHeaders.getAllRawHeaders() @@ -497,7 +497,7 @@ def start_active_span_from_edu( """ if opentracing is None: - return _noop_context_manager() + return noop_context_manager() carrier = json_decoder.decode(edu_content.get("context", "{}")).get( "opentracing", {} diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index b8d2a8e8a9..cbf0dbb871 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -502,6 +502,16 @@ build_info.labels( last_ticked = time.time() +# 3PID send info +threepid_send_requests = Histogram( + "synapse_threepid_send_requests_with_tries", + documentation="Number of requests for a 3pid token by try count. Note if" + " there is a request with try count of 4, then there would have been one" + " each for 1, 2 and 3", + buckets=(1, 2, 3, 4, 5, 10), + labelnames=("type", "reason"), +) + class ReactorLastSeenMetric: def collect(self): diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
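The new threepid_send_requests histogram above is labelled by type and reason and observes the try count for each 3PID token request. A hypothetical call site; the label values and attempt number are assumptions:

from synapse.metrics import threepid_send_requests

send_attempt = 1  # first request for this 3PID token
threepid_send_requests.labels(type="email", reason="register").observe(send_attempt)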
index 5b73463504..70e0fa45d9 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect import logging import threading from functools import wraps @@ -24,6 +23,8 @@ from prometheus_client.core import REGISTRY, Counter, Gauge from twisted.internet import defer from synapse.logging.context import LoggingContext, PreserveLoggingContext +from synapse.logging.opentracing import noop_context_manager, start_active_span +from synapse.util.async_helpers import maybe_awaitable if TYPE_CHECKING: import resource @@ -166,7 +167,7 @@ class _BackgroundProcess: ) -def run_as_background_process(desc: str, func, *args, **kwargs): +def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwargs): """Run the given function in its own logcontext, with resource metrics This should be used to wrap processes which are fired off to run in the @@ -180,6 +181,9 @@ def run_as_background_process(desc: str, func, *args, **kwargs): Args: desc: a description for this background process type func: a function, which may return a Deferred or a coroutine + bg_start_span: Whether to start an opentracing span. Defaults to True. + Should only be disabled for processes that will not log to or tag + a span. args: positional args for func kwargs: keyword args for func @@ -195,16 +199,13 @@ def run_as_background_process(desc: str, func, *args, **kwargs): _background_process_start_count.labels(desc).inc() _background_process_in_flight_count.labels(desc).inc() - with BackgroundProcessLoggingContext(desc) as context: - context.request = "%s-%i" % (desc, count) - + with BackgroundProcessLoggingContext(desc, "%s-%i" % (desc, count)) as context: try: - result = func(*args, **kwargs) - - if inspect.isawaitable(result): - result = await result - - return result + ctx = noop_context_manager() + if bg_start_span: + ctx = start_active_span(desc, tags={"request_id": context.request}) + with ctx: + return await maybe_awaitable(func(*args, **kwargs)) except Exception: logger.exception( "Background process '%s' threw an exception", desc, @@ -242,8 +243,8 @@ class BackgroundProcessLoggingContext(LoggingContext): __slots__ = ["_proc"] - def __init__(self, name: str): - super().__init__(name) + def __init__(self, name: str, request: Optional[str] = None): + super().__init__(name, request=request) self._proc = _BackgroundProcess(name, self) @@ -265,7 +266,7 @@ class BackgroundProcessLoggingContext(LoggingContext): super().__exit__(type, value, traceback) - # The background process has finished. We explictly remove and manually + # The background process has finished. We explicitly remove and manually # update the metrics here so that if nothing is scraping metrics the set # doesn't infinitely grow. with _bg_metrics_lock: diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
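run_as_background_process above gains a bg_start_span keyword so callers that never log to or tag a span can skip creating one. A sketch with a made-up task name and coroutine:

from synapse.metrics.background_process_metrics import run_as_background_process


async def _prune_old_entries() -> None:
    pass  # placeholder body for illustration


run_as_background_process(
    "prune_old_entries", _prune_old_entries, bg_start_span=False
)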
index fcbd5378c4..72ab5750cc 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -14,12 +14,19 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import TYPE_CHECKING, Iterable, Optional, Tuple from twisted.internet import defer +from synapse.events import EventBase +from synapse.http.client import SimpleHttpClient from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background -from synapse.types import UserID +from synapse.storage.state import StateFilter +from synapse.types import JsonDict, UserID, create_requester + +if TYPE_CHECKING: + from synapse.server import HomeServer """ This package defines the 'stable' API which can be used by extension modules which @@ -42,6 +49,28 @@ class ModuleApi: self._store = hs.get_datastore() self._auth = hs.get_auth() self._auth_handler = auth_handler + self._server_name = hs.hostname + + # We expose these as properties below in order to attach a helpful docstring. + self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient + self._public_room_list_manager = PublicRoomListManager(hs) + + @property + def http_client(self): + """Allows making outbound HTTP requests to remote resources. + + An instance of synapse.http.client.SimpleHttpClient + """ + return self._http_client + + @property + def public_room_list_manager(self): + """Allows adding to, removing from and checking the status of rooms in the + public room list. + + An instance of synapse.module_api.PublicRoomListManager + """ + return self._public_room_list_manager def get_user_by_req(self, req, allow_guest=False): """Check the access_token provided for a request @@ -266,3 +295,99 @@ class ModuleApi: await self._auth_handler.complete_sso_login( registered_user_id, request, client_redirect_url, ) + + @defer.inlineCallbacks + def get_state_events_in_room( + self, room_id: str, types: Iterable[Tuple[str, Optional[str]]] + ) -> defer.Deferred: + """Gets current state events for the given room. + + (This is exposed for compatibility with the old SpamCheckerApi. We should + probably deprecate it and replace it with an async method in a subclass.) + + Args: + room_id: The room ID to get state events in. + types: The event type and state key (using None + to represent 'any') of the room state to acquire. + + Returns: + twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]: + The filtered state events in the room. + """ + state_ids = yield defer.ensureDeferred( + self._store.get_filtered_current_state_ids( + room_id=room_id, state_filter=StateFilter.from_types(types) + ) + ) + state = yield defer.ensureDeferred(self._store.get_events(state_ids.values())) + return state.values() + + async def create_and_send_event_into_room(self, event_dict: JsonDict) -> EventBase: + """Create and send an event into a room. Membership events are currently not supported. + + Args: + event_dict: A dictionary representing the event to send. + Required keys are `type`, `room_id`, `sender` and `content`. + + Returns: + The event that was sent. If state event deduplication happened, then + the previous, duplicate event instead. + + Raises: + SynapseError if the event was not allowed. 
+ """ + # Create a requester object + requester = create_requester( + event_dict["sender"], authenticated_entity=self._server_name + ) + + # Create and send the event + ( + event, + _, + ) = await self._hs.get_event_creation_handler().create_and_send_nonmember_event( + requester, event_dict, ratelimit=False, ignore_shadow_ban=True, + ) + + return event + + +class PublicRoomListManager: + """Contains methods for adding to, removing from and querying whether a room + is in the public room list. + """ + + def __init__(self, hs: "HomeServer"): + self._store = hs.get_datastore() + + async def room_is_in_public_room_list(self, room_id: str) -> bool: + """Checks whether a room is in the public room list. + + Args: + room_id: The ID of the room. + + Returns: + Whether the room is in the public room list. Returns False if the room does + not exist. + """ + room = await self._store.get_room(room_id) + if not room: + return False + + return room.get("is_public", False) + + async def add_room_to_public_room_list(self, room_id: str) -> None: + """Publishes a room to the public room list. + + Args: + room_id: The ID of the room. + """ + await self._store.set_room_is_public(room_id, True) + + async def remove_room_from_public_room_list(self, room_id: str) -> None: + """Removes a room from the public room list. + + Args: + room_id: The ID of the room. + """ + await self._store.set_room_is_public(room_id, False) diff --git a/synapse/notifier.py b/synapse/notifier.py
index 59415f6f88..c4c8bb271d 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -28,19 +28,19 @@ from typing import ( Union, ) +import attr from prometheus_client import Counter from twisted.internet import defer import synapse.server -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes, HistoryVisibility, Membership from synapse.api.errors import AuthError from synapse.events import EventBase from synapse.handlers.presence import format_user_presence_state from synapse.logging.context import PreserveLoggingContext from synapse.logging.utils import log_function from synapse.metrics import LaterGauge -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.streams.config import PaginationConfig from synapse.types import ( Collection, @@ -174,6 +174,17 @@ class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))): return bool(self.events) +@attr.s(slots=True, frozen=True) +class _PendingRoomEventEntry: + event_pos = attr.ib(type=PersistedEventPosition) + extra_users = attr.ib(type=Collection[UserID]) + + room_id = attr.ib(type=str) + type = attr.ib(type=str) + state_key = attr.ib(type=Optional[str]) + membership = attr.ib(type=Optional[str]) + + class Notifier: """ This class is responsible for notifying any listeners when there are new events available for it. @@ -191,9 +202,7 @@ class Notifier: self.storage = hs.get_storage() self.event_sources = hs.get_event_sources() self.store = hs.get_datastore() - self.pending_new_room_events = ( - [] - ) # type: List[Tuple[PersistedEventPosition, EventBase, Collection[UserID]]] + self.pending_new_room_events = [] # type: List[_PendingRoomEventEntry] # Called when there are new things to stream over replication self.replication_callbacks = [] # type: List[Callable[[], None]] @@ -256,7 +265,29 @@ class Notifier: max_room_stream_token: RoomStreamToken, extra_users: Collection[UserID] = [], ): - """ Used by handlers to inform the notifier something has happened + """Unwraps event and calls `on_new_room_event_args`. + """ + self.on_new_room_event_args( + event_pos=event_pos, + room_id=event.room_id, + event_type=event.type, + state_key=event.get("state_key"), + membership=event.content.get("membership"), + max_room_stream_token=max_room_stream_token, + extra_users=extra_users, + ) + + def on_new_room_event_args( + self, + room_id: str, + event_type: str, + state_key: Optional[str], + membership: Optional[str], + event_pos: PersistedEventPosition, + max_room_stream_token: RoomStreamToken, + extra_users: Collection[UserID] = [], + ): + """Used by handlers to inform the notifier something has happened in the room, room event wise. This triggers the notifier to wake up any listeners that are @@ -267,7 +298,16 @@ class Notifier: until all previous events have been persisted before notifying the client streams. 
""" - self.pending_new_room_events.append((event_pos, event, extra_users)) + self.pending_new_room_events.append( + _PendingRoomEventEntry( + event_pos=event_pos, + extra_users=extra_users, + room_id=room_id, + type=event_type, + state_key=state_key, + membership=membership, + ) + ) self._notify_pending_new_room_events(max_room_stream_token) self.notify_replication() @@ -285,18 +325,19 @@ class Notifier: users = set() # type: Set[UserID] rooms = set() # type: Set[str] - for event_pos, event, extra_users in pending: - if event_pos.persisted_after(max_room_stream_token): - self.pending_new_room_events.append((event_pos, event, extra_users)) + for entry in pending: + if entry.event_pos.persisted_after(max_room_stream_token): + self.pending_new_room_events.append(entry) else: if ( - event.type == EventTypes.Member - and event.membership == Membership.JOIN + entry.type == EventTypes.Member + and entry.membership == Membership.JOIN + and entry.state_key ): - self._user_joined_room(event.state_key, event.room_id) + self._user_joined_room(entry.state_key, entry.room_id) - users.update(extra_users) - rooms.add(event.room_id) + users.update(entry.extra_users) + rooms.add(entry.room_id) if users or rooms: self.on_new_event( @@ -310,28 +351,37 @@ class Notifier: """ # poke any interested application service. - run_as_background_process( - "_notify_app_services", self._notify_app_services, max_room_stream_token - ) - - run_as_background_process( - "_notify_pusher_pool", self._notify_pusher_pool, max_room_stream_token - ) + self._notify_app_services(max_room_stream_token) + self._notify_pusher_pool(max_room_stream_token) if self.federation_sender: - self.federation_sender.notify_new_events(max_room_stream_token.stream) + self.federation_sender.notify_new_events(max_room_stream_token) - async def _notify_app_services(self, max_room_stream_token: RoomStreamToken): + def _notify_app_services(self, max_room_stream_token: RoomStreamToken): try: - await self.appservice_handler.notify_interested_services( - max_room_stream_token.stream + self.appservice_handler.notify_interested_services(max_room_stream_token) + except Exception: + logger.exception("Error notifying application services of event") + + def _notify_app_services_ephemeral( + self, + stream_key: str, + new_token: Union[int, RoomStreamToken], + users: Collection[Union[str, UserID]] = [], + ): + try: + stream_token = None + if isinstance(new_token, int): + stream_token = new_token + self.appservice_handler.notify_interested_services_ephemeral( + stream_key, stream_token, users ) except Exception: logger.exception("Error notifying application services of event") - async def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken): + def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken): try: - await self._pusher_pool.on_new_notifications(max_room_stream_token.stream) + self._pusher_pool.on_new_notifications(max_room_stream_token) except Exception: logger.exception("Error pusher pool of event") @@ -339,7 +389,7 @@ class Notifier: self, stream_key: str, new_token: Union[int, RoomStreamToken], - users: Collection[UserID] = [], + users: Collection[Union[str, UserID]] = [], rooms: Collection[str] = [], ): """ Used to inform listeners that something has happened event wise. 
@@ -367,8 +417,13 @@ class Notifier: self.notify_replication() + # Notify appservices + self._notify_app_services_ephemeral( + stream_key, new_token, users, + ) + def on_new_replication_data(self) -> None: - """Used to inform replication listeners that something has happend + """Used to inform replication listeners that something has happened without waking up any of the normal user event streams""" self.notify_replication() @@ -556,7 +611,9 @@ class Notifier: room_id, EventTypes.RoomHistoryVisibility, "" ) if state and "history_visibility" in state.content: - return state.content["history_visibility"] == "world_readable" + return ( + state.content["history_visibility"] == HistoryVisibility.WORLD_READABLE + ) else: return False diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 5a437f9810..f4f7ec96f8 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -13,7 +13,111 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc +from typing import TYPE_CHECKING, Any, Dict, Optional + +import attr + +from synapse.types import JsonDict, RoomStreamToken + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + + +@attr.s(slots=True) +class PusherConfig: + """Parameters necessary to configure a pusher.""" + + id = attr.ib(type=Optional[str]) + user_name = attr.ib(type=str) + access_token = attr.ib(type=Optional[int]) + profile_tag = attr.ib(type=str) + kind = attr.ib(type=str) + app_id = attr.ib(type=str) + app_display_name = attr.ib(type=str) + device_display_name = attr.ib(type=str) + pushkey = attr.ib(type=str) + ts = attr.ib(type=int) + lang = attr.ib(type=Optional[str]) + data = attr.ib(type=Optional[JsonDict]) + last_stream_ordering = attr.ib(type=int) + last_success = attr.ib(type=Optional[int]) + failing_since = attr.ib(type=Optional[int]) + + def as_dict(self) -> Dict[str, Any]: + """Information that can be retrieved about a pusher after creation.""" + return { + "app_display_name": self.app_display_name, + "app_id": self.app_id, + "data": self.data, + "device_display_name": self.device_display_name, + "kind": self.kind, + "lang": self.lang, + "profile_tag": self.profile_tag, + "pushkey": self.pushkey, + } + + +@attr.s(slots=True) +class ThrottleParams: + """Parameters for controlling the rate of sending pushes via email.""" + + last_sent_ts = attr.ib(type=int) + throttle_ms = attr.ib(type=int) + + +class Pusher(metaclass=abc.ABCMeta): + def __init__(self, hs: "HomeServer", pusher_config: PusherConfig): + self.hs = hs + self.store = self.hs.get_datastore() + self.clock = self.hs.get_clock() + + self.pusher_id = pusher_config.id + self.user_id = pusher_config.user_name + self.app_id = pusher_config.app_id + self.pushkey = pusher_config.pushkey + + self.last_stream_ordering = pusher_config.last_stream_ordering + + # This is the highest stream ordering we know it's safe to process. + # When new events arrive, we'll be given a window of new events: we + # should honour this rather than just looking for anything higher + # because of potential out-of-order event serialisation. + self.max_stream_ordering = self.store.get_room_max_stream_ordering() + + def on_new_notifications(self, max_token: RoomStreamToken) -> None: + # We just use the minimum stream ordering and ignore the vector clock + # component. This is safe to do as long as we *always* ignore the vector + # clock components. + max_stream_ordering = max_token.stream + + self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering) + self._start_processing() + + @abc.abstractmethod + def _start_processing(self): + """Start processing push notifications.""" + raise NotImplementedError() + + @abc.abstractmethod + def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def on_started(self, have_notifs: bool) -> None: + """Called when this pusher has been started. + + Args: + should_check_for_notifs: Whether we should immediately + check for push to send. 
Set to False only if it's known there + is nothing to send + """ + raise NotImplementedError() + + @abc.abstractmethod + def on_stop(self) -> None: + raise NotImplementedError() + class PusherConfigException(Exception): - def __init__(self, msg): - super().__init__(msg) + """An error occurred when creating a pusher.""" diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index fabc9ba126..aaed28650d 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -14,19 +14,22 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING +from synapse.events import EventBase +from synapse.events.snapshot import EventContext +from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator from synapse.util.metrics import Measure -from .bulk_push_rule_evaluator import BulkPushRuleEvaluator +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) class ActionGenerator: - def __init__(self, hs): - self.hs = hs + def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() - self.store = hs.get_datastore() self.bulk_evaluator = BulkPushRuleEvaluator(hs) # really we want to get all user ids and all profile tags too, # since we want the actions for each profile tag for every user and @@ -35,6 +38,8 @@ class ActionGenerator: # event stream, so we just run the rules for a client with no profile # tag (ie. we just need all the users). - async def handle_push_actions_for_event(self, event, context): + async def handle_push_actions_for_event( + self, event: EventBase, context: EventContext + ) -> None: with Measure(self.clock, "action_for_event_by_user"): await self.bulk_evaluator.action_for_event_by_user(event, context) diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 8047873ff1..6211506990 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -15,16 +15,19 @@ # limitations under the License. import copy +from typing import Any, Dict, List from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP -def list_with_base_rules(rawrules, use_new_defaults=False): +def list_with_base_rules( + rawrules: List[Dict[str, Any]], use_new_defaults: bool = False +) -> List[Dict[str, Any]]: """Combine the list of rules set by the user with the default push rules Args: - rawrules(list): The rules the user has modified or set. - use_new_defaults(bool): Whether to use the new experimental default rules when + rawrules: The rules the user has modified or set. + use_new_defaults: Whether to use the new experimental default rules when appending or prepending default rules. Returns: @@ -37,7 +40,7 @@ def list_with_base_rules(rawrules, use_new_defaults=False): modified_base_rules = {r["rule_id"]: r for r in rawrules if r["priority_class"] < 0} # Remove the modified base rules from the list, They'll be added back - # in the default postions in the list. + # in the default positions in the list. rawrules = [r for r in rawrules if r["priority_class"] >= 0] # shove the server default rules for each kind onto the end of each @@ -94,7 +97,11 @@ def list_with_base_rules(rawrules, use_new_defaults=False): return ruleslist -def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False): +def make_base_append_rules( + kind: str, + modified_base_rules: Dict[str, Dict[str, Any]], + use_new_defaults: bool = False, +) -> List[Dict[str, Any]]: rules = [] if kind == "override": @@ -116,6 +123,7 @@ def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False): rules = copy.deepcopy(rules) for r in rules: # Only modify the actions, keep the conditions the same. + assert isinstance(r["rule_id"], str) modified = modified_base_rules.get(r["rule_id"]) if modified: r["actions"] = modified["actions"] @@ -123,7 +131,11 @@ def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False): return rules -def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False): +def make_base_prepend_rules( + kind: str, + modified_base_rules: Dict[str, Dict[str, Any]], + use_new_defaults: bool = False, +) -> List[Dict[str, Any]]: rules = [] if kind == "override": @@ -133,6 +145,7 @@ def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False): rules = copy.deepcopy(rules) for r in rules: # Only modify the actions, keep the conditions the same. + assert isinstance(r["rule_id"], str) modified = modified_base_rules.get(r["rule_id"]) if modified: r["actions"] = modified["actions"] @@ -498,6 +511,30 @@ BASE_APPEND_UNDERRIDE_RULES = [ ], "actions": ["notify", {"set_tweak": "highlight", "value": False}], }, + { + "rule_id": "global/underride/.im.vector.jitsi", + "conditions": [ + { + "kind": "event_match", + "key": "type", + "pattern": "im.vector.modular.widgets", + "_id": "_type_modular_widgets", + }, + { + "kind": "event_match", + "key": "content.type", + "pattern": "jitsi", + "_id": "_content_type_jitsi", + }, + { + "kind": "event_match", + "key": "state_key", + "pattern": "*", + "_id": "_is_state_event", + }, + ], + "actions": ["notify", {"set_tweak": "highlight", "value": False}], + }, ] diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
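The new ".im.vector.jitsi" underride rule above matches state events that add a Jitsi widget. An event shaped roughly like this would trigger it; the identifiers and URL are illustrative, not taken from this patch:

jitsi_widget_event = {
    "type": "im.vector.modular.widgets",
    "state_key": "jitsi_widget_id",  # the "*" pattern requires a state event
    "content": {
        "type": "jitsi",
        "url": "https://jitsi.example.com/SomeConference",
    },
    "room_id": "!room:example.com",
    "sender": "@alice:example.com",
}
# Matching events get the rule's actions:
# ["notify", {"set_tweak": "highlight", "value": False}]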
index c440f2545c..10f27e4378 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -15,8 +15,9 @@ # limitations under the License. import logging -from collections import namedtuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union +import attr from prometheus_client import Counter from synapse.api.constants import EventTypes, Membership, RelationTypes @@ -25,15 +26,16 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.util.async_helpers import Linearizer -from synapse.util.caches import register_cache -from synapse.util.caches.descriptors import cached +from synapse.util.caches import CacheMetric, register_cache +from synapse.util.caches.descriptors import lru_cache +from synapse.util.caches.lrucache import LruCache from .push_rule_evaluator import PushRuleEvaluatorForEvent -logger = logging.getLogger(__name__) - +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer -rules_by_room = {} +logger = logging.getLogger(__name__) push_rules_invalidation_counter = Counter( @@ -100,7 +102,7 @@ class BulkPushRuleEvaluator: room at once. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() @@ -112,7 +114,9 @@ class BulkPushRuleEvaluator: resizable=False, ) - async def _get_rules_for_event(self, event, context): + async def _get_rules_for_event( + self, event: EventBase, context: EventContext + ) -> Dict[str, List[Dict[str, Any]]]: """This gets the rules for all users in the room at the time of the event, as well as the push rules for the invitee if the event is an invite. @@ -120,7 +124,7 @@ class BulkPushRuleEvaluator: dict of user_id -> push_rules """ room_id = event.room_id - rules_for_room = await self._get_rules_for_room(room_id) + rules_for_room = self._get_rules_for_room(room_id) rules_by_user = await rules_for_room.get_rules(event, context) @@ -138,12 +142,9 @@ class BulkPushRuleEvaluator: return rules_by_user - @cached() - def _get_rules_for_room(self, room_id): + @lru_cache() + def _get_rules_for_room(self, room_id: str) -> "RulesForRoom": """Get the current RulesForRoom object for the given room id - - Returns: - RulesForRoom """ # It's important that RulesForRoom gets added to self._get_rules_for_room.cache # before any lookup methods get called on it as otherwise there may be @@ -155,20 +156,21 @@ class BulkPushRuleEvaluator: self.room_push_rule_cache_metrics, ) - async def _get_power_levels_and_sender_level(self, event, context): + async def _get_power_levels_and_sender_level( + self, event: EventBase, context: EventContext + ) -> Tuple[dict, int]: prev_state_ids = await context.get_prev_state_ids() pl_event_id = prev_state_ids.get(POWER_KEY) if pl_event_id: # fastpath: if there's a power level event, that's all we need, and # not having a power level event is an extreme edge case - pl_event = await self.store.get_event(pl_event_id) - auth_events = {POWER_KEY: pl_event} + auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)} else: auth_events_ids = self.auth.compute_auth_events( event, prev_state_ids, for_verification=False ) - auth_events = await self.store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events.values()} + auth_events_dict = await self.store.get_events(auth_events_ids) + auth_events = {(e.type, e.state_key): e for e in auth_events_dict.values()} sender_level = get_user_power_level(event.sender, auth_events) @@ -176,7 +178,9 @@ class BulkPushRuleEvaluator: return pl_event.content if 
pl_event else {}, sender_level - async def action_for_event_by_user(self, event, context) -> None: + async def action_for_event_by_user( + self, event: EventBase, context: EventContext + ) -> None: """Given an event and context, evaluate the push rules, check if the message should increment the unread count, and insert the results into the event_push_actions_staging table. @@ -184,7 +188,7 @@ class BulkPushRuleEvaluator: count_as_unread = _should_count_as_unread(event, context) rules_by_user = await self._get_rules_for_event(event, context) - actions_by_user = {} + actions_by_user = {} # type: Dict[str, List[Union[dict, str]]] room_members = await self.store.get_joined_users_from_context(event, context) @@ -197,7 +201,7 @@ class BulkPushRuleEvaluator: event, len(room_members), sender_power_level, power_levels ) - condition_cache = {} + condition_cache = {} # type: Dict[str, bool] for uid, rules in rules_by_user.items(): if event.sender == uid: @@ -248,7 +252,13 @@ class BulkPushRuleEvaluator: ) -def _condition_checker(evaluator, conditions, uid, display_name, cache): +def _condition_checker( + evaluator: PushRuleEvaluatorForEvent, + conditions: List[dict], + uid: str, + display_name: str, + cache: Dict[str, bool], +) -> bool: for cond in conditions: _id = cond.get("_id", None) if _id: @@ -275,14 +285,20 @@ class RulesForRoom: the entire cache for the room. """ - def __init__(self, hs, room_id, rules_for_room_cache, room_push_rule_cache_metrics): + def __init__( + self, + hs: "HomeServer", + room_id: str, + rules_for_room_cache: LruCache, + room_push_rule_cache_metrics: CacheMetric, + ): """ Args: - hs (HomeServer) - room_id (str) - rules_for_room_cache(Cache): The cache object that caches these + hs: The HomeServer object. + room_id: The room ID. + rules_for_room_cache: The cache object that caches these RoomsForUser objects. - room_push_rule_cache_metrics (CacheMetric) + room_push_rule_cache_metrics: The metrics object """ self.room_id = room_id self.is_mine_id = hs.is_mine_id @@ -291,8 +307,10 @@ class RulesForRoom: self.linearizer = Linearizer(name="rules_for_room") - self.member_map = {} # event_id -> (user_id, state) - self.rules_by_user = {} # user_id -> rules + # event_id -> (user_id, state) + self.member_map = {} # type: Dict[str, Tuple[str, str]] + # user_id -> rules + self.rules_by_user = {} # type: Dict[str, List[Dict[str, dict]]] # The last state group we updated the caches for. If the state_group of # a new event comes along, we know that we can just return the cached @@ -312,7 +330,7 @@ class RulesForRoom: # calculate push for) # These never need to be invalidated as we will never set up push for # them. - self.uninteresting_user_set = set() + self.uninteresting_user_set = set() # type: Set[str] # We need to be clever on the invalidating caches callbacks, as # otherwise the invalidation callback holds a reference to the object, @@ -322,7 +340,9 @@ class RulesForRoom: # to self around in the callback. self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id) - async def get_rules(self, event, context): + async def get_rules( + self, event: EventBase, context: EventContext + ) -> Dict[str, List[Dict[str, dict]]]: """Given an event context return the rules for all users who are currently in the room. """ @@ -353,6 +373,8 @@ class RulesForRoom: else: current_state_ids = await context.get_current_state_ids() push_rules_delta_state_cache_metric.inc_misses() + # Ensure the state IDs exist. 
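The caching strategy above hinges on the event context's state group: if it matches the group the cache was last filled for, the per-user rules can be returned without touching the database, and otherwise only the changed membership events are looked up. A rough sketch of that shortcut, with simplified names (recompute stands in for the delta/full-state lookups and is not a real Synapse helper):

from typing import Awaitable, Callable, Dict, List, Optional

class CachedRoomRules:
    """Illustrative only: per-user push rules cached against a state group."""

    def __init__(self) -> None:
        # A sentinel that never compares equal to a real state group.
        self.state_group = object()  # type: object
        self.rules_by_user = {}  # type: Dict[str, List[dict]]

    async def get_rules(
        self,
        state_group: Optional[int],
        recompute: Callable[[], Awaitable[Dict[str, List[dict]]]],
    ) -> Dict[str, List[dict]]:
        # Fast path: same state group as last time, so room membership (and
        # therefore the set of users whose rules we need) cannot have changed.
        if state_group is not None and state_group == self.state_group:
            return self.rules_by_user

        # Slow path: rebuild the mapping and remember which group it was for.
        self.rules_by_user = await recompute()
        self.state_group = state_group
        return self.rules_by_user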
+ assert current_state_ids is not None push_rules_state_size_counter.inc(len(current_state_ids)) @@ -390,12 +412,12 @@ class RulesForRoom: continue # If a user has left a room we remove their push rule. If they - # joined then we readd it later in _update_rules_with_member_event_ids + # joined then we re-add it later in _update_rules_with_member_event_ids ret_rules_by_user.pop(user_id, None) missing_member_event_ids[user_id] = event_id if missing_member_event_ids: - # If we have some memebr events we haven't seen, look them up + # If we have some member events we haven't seen, look them up # and fetch push rules for them if appropriate. logger.debug("Found new member events %r", missing_member_event_ids) await self._update_rules_with_member_event_ids( @@ -417,18 +439,23 @@ class RulesForRoom: return ret_rules_by_user async def _update_rules_with_member_event_ids( - self, ret_rules_by_user, member_event_ids, state_group, event - ): + self, + ret_rules_by_user: Dict[str, list], + member_event_ids: Dict[str, str], + state_group: Optional[int], + event: EventBase, + ) -> None: """Update the partially filled rules_by_user dict by fetching rules for any newly joined users in the `member_event_ids` list. Args: - ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets + ret_rules_by_user: Partially filled dict of push rules. Gets updated with any new rules. - member_event_ids (dict): Dict of user id to event id for membership events + member_event_ids: Dict of user id to event id for membership events that have happened since the last time we filled rules_by_user state_group: The state group we are currently computing push rules for. Used when updating the cache. + event: The event we are currently computing push rules for. """ sequence = self.sequence @@ -446,19 +473,19 @@ class RulesForRoom: if logger.isEnabledFor(logging.DEBUG): logger.debug("Found members %r: %r", self.room_id, members.values()) - user_ids = { + joined_user_ids = { user_id for user_id, membership in members.values() if membership == Membership.JOIN } - logger.debug("Joined: %r", user_ids) + logger.debug("Joined: %r", joined_user_ids) # Previously we only considered users with pushers or read receipts in that # room. We can't do this anymore because we use push actions to calculate unread # counts, which don't rely on the user having pushers or sent a read receipt into # the room. Therefore we just need to filter for local users here. - user_ids = list(filter(self.is_mine_id, user_ids)) + user_ids = list(filter(self.is_mine_id, joined_user_ids)) rules_by_user = await self.store.bulk_get_push_rules( user_ids, on_invalidate=self.invalidate_all_cb @@ -470,7 +497,7 @@ class RulesForRoom: self.update_cache(sequence, members, ret_rules_by_user, state_group) - def invalidate_all(self): + def invalidate_all(self) -> None: # Note: Don't hand this function directly to an invalidation callback # as it keeps a reference to self and will stop this instance from being # GC'd if it gets dropped from the rules_to_user cache. 
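The comment above is why invalidation is routed through a small value object rather than a bound method: the callback must not keep the RulesForRoom alive, and duplicate registrations should compare equal so the cache entry does not accumulate them. A minimal demonstration of the frozen-attrs equality that the new _Invalidation class relies on (the names below are illustrative, not Synapse code):

import attr

@attr.attrs(slots=True, frozen=True)
class _Key:
    # Holds only the cache and the room id, never the RulesForRoom itself,
    # so evicting the cache entry lets the RulesForRoom be garbage collected.
    cache = attr.ib()
    room_id = attr.ib(type=str)

shared_cache = object()  # stand-in for an LruCache instance

# frozen=True makes attrs generate __eq__ and __hash__ from the field values,
# so two callbacks built for the same cache and room are interchangeable and
# can be deduplicated instead of piling up on the cache entry.
assert _Key(shared_cache, "!room:example.org") == _Key(shared_cache, "!room:example.org")
assert hash(_Key(shared_cache, "!room:example.org")) == hash(_Key(shared_cache, "!room:example.org"))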
Instead use @@ -482,20 +509,28 @@ class RulesForRoom: self.rules_by_user = {} push_rules_invalidation_counter.inc() - def update_cache(self, sequence, members, rules_by_user, state_group): + def update_cache(self, sequence, members, rules_by_user, state_group) -> None: if sequence == self.sequence: self.member_map.update(members) self.rules_by_user = rules_by_user self.state_group = state_group -class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))): - # We rely on _CacheContext implementing __eq__ and __hash__ sensibly, - # which namedtuple does for us (i.e. two _CacheContext are the same if - # their caches and keys match). This is important in particular to - # dedupe when we add callbacks to lru cache nodes, otherwise the number - # of callbacks would grow. - def __call__(self): +@attr.attrs(slots=True, frozen=True) +class _Invalidation: + # _Invalidation is passed as an `on_invalidate` callback to bulk_get_push_rules, + # which means that it is stored on the bulk_get_push_rules cache entry. In order + # to ensure that we don't accumulate lots of redundant callbacks on the cache entry, + # we need to ensure that two _Invalidation objects are "equal" if they refer to the + # same `cache` and `room_id`. + # + # attrs provides suitable __hash__ and __eq__ methods, provided we remember to + # set `frozen=True`. + + cache = attr.ib(type=LruCache) + room_id = attr.ib(type=str) + + def __call__(self) -> None: rules = self.cache.get(self.room_id, None, update_metrics=False) if rules: rules.invalidate_all() diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
index a59b639f15..0cadba761a 100644 --- a/synapse/push/clientformat.py +++ b/synapse/push/clientformat.py
@@ -14,24 +14,27 @@ # limitations under the License. import copy +from typing import Any, Dict, List, Optional from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP +from synapse.types import UserID -def format_push_rules_for_user(user, ruleslist): +def format_push_rules_for_user(user: UserID, ruleslist) -> Dict[str, Dict[str, list]]: """Converts a list of rawrules and a enabled map into nested dictionaries to match the Matrix client-server format for push rules""" # We're going to be mutating this a lot, so do a deep copy ruleslist = copy.deepcopy(ruleslist) - rules = {"global": {}, "device": {}} + rules = { + "global": {}, + "device": {}, + } # type: Dict[str, Dict[str, List[Dict[str, Any]]]] rules["global"] = _add_empty_priority_class_arrays(rules["global"]) for r in ruleslist: - rulearray = None - template_name = _priority_class_to_template_name(r["priority_class"]) # Remove internal stuff. @@ -57,13 +60,13 @@ def format_push_rules_for_user(user, ruleslist): return rules -def _add_empty_priority_class_arrays(d): +def _add_empty_priority_class_arrays(d: Dict[str, list]) -> Dict[str, list]: for pc in PRIORITY_CLASS_MAP.keys(): d[pc] = [] return d -def _rule_to_template(rule): +def _rule_to_template(rule: Dict[str, Any]) -> Optional[Dict[str, Any]]: unscoped_rule_id = None if "rule_id" in rule: unscoped_rule_id = _rule_id_from_namespaced(rule["rule_id"]) @@ -82,6 +85,10 @@ def _rule_to_template(rule): return None templaterule = {"actions": rule["actions"]} templaterule["pattern"] = thecond["pattern"] + else: + # This should not be reached unless this function is not kept in sync + # with PRIORITY_CLASS_INVERSE_MAP. + raise ValueError("Unexpected template_name: %s" % (template_name,)) if unscoped_rule_id: templaterule["rule_id"] = unscoped_rule_id @@ -90,9 +97,9 @@ def _rule_to_template(rule): return templaterule -def _rule_id_from_namespaced(in_rule_id): +def _rule_id_from_namespaced(in_rule_id: str) -> str: return in_rule_id.split("/")[-1] -def _priority_class_to_template_name(pc): +def _priority_class_to_template_name(pc: int) -> str: return PRIORITY_CLASS_INVERSE_MAP[pc] diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 28bd8ab748..4ac1b31748 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py
@@ -14,10 +14,17 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Dict, List, Optional +from twisted.internet.base import DelayedCall from twisted.internet.error import AlreadyCalled, AlreadyCancelled from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.push import Pusher, PusherConfig, ThrottleParams +from synapse.push.mailer import Mailer + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) @@ -45,7 +52,7 @@ THROTTLE_RESET_AFTER_MS = 12 * 60 * 60 * 1000 INCLUDE_ALL_UNREAD_NOTIFS = False -class EmailPusher: +class EmailPusher(Pusher): """ A pusher that sends email notifications about events (approximately) when they happen. @@ -53,37 +60,30 @@ class EmailPusher: factor out the common parts """ - def __init__(self, hs, pusherdict, mailer): - self.hs = hs + def __init__(self, hs: "HomeServer", pusher_config: PusherConfig, mailer: Mailer): + super().__init__(hs, pusher_config) self.mailer = mailer self.store = self.hs.get_datastore() - self.clock = self.hs.get_clock() - self.pusher_id = pusherdict["id"] - self.user_id = pusherdict["user_name"] - self.app_id = pusherdict["app_id"] - self.email = pusherdict["pushkey"] - self.last_stream_ordering = pusherdict["last_stream_ordering"] - self.timed_call = None - self.throttle_params = None - - # See httppusher - self.max_stream_ordering = None + self.email = pusher_config.pushkey + self.timed_call = None # type: Optional[DelayedCall] + self.throttle_params = {} # type: Dict[str, ThrottleParams] + self._inited = False self._is_processing = False - def on_started(self, should_check_for_notifs): + def on_started(self, should_check_for_notifs: bool) -> None: """Called when this pusher has been started. Args: - should_check_for_notifs (bool): Whether we should immediately + should_check_for_notifs: Whether we should immediately check for push to send. Set to False only if it's known there is nothing to send """ if should_check_for_notifs and self.mailer is not None: self._start_processing() - def on_stop(self): + def on_stop(self) -> None: if self.timed_call: try: self.timed_call.cancel() @@ -91,32 +91,23 @@ class EmailPusher: pass self.timed_call = None - def on_new_notifications(self, max_stream_ordering): - if self.max_stream_ordering: - self.max_stream_ordering = max( - max_stream_ordering, self.max_stream_ordering - ) - else: - self.max_stream_ordering = max_stream_ordering - self._start_processing() - - def on_new_receipts(self, min_stream_id, max_stream_id): + def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None: # We could wake up and cancel the timer but there tend to be quite a # lot of read receipts so it's probably less work to just let the # timer fire pass - def on_timer(self): + def on_timer(self) -> None: self.timed_call = None self._start_processing() - def _start_processing(self): + def _start_processing(self) -> None: if self._is_processing: return run_as_background_process("emailpush.process", self._process) - def _pause_processing(self): + def _pause_processing(self) -> None: """Used by tests to temporarily pause processing of events. Asserts that its not currently processing. @@ -124,25 +115,27 @@ class EmailPusher: assert not self._is_processing self._is_processing = True - def _resume_processing(self): + def _resume_processing(self) -> None: """Used by tests to resume processing of events after pausing. 
""" assert self._is_processing self._is_processing = False self._start_processing() - async def _process(self): + async def _process(self) -> None: # we should never get here if we are already processing assert not self._is_processing try: self._is_processing = True - if self.throttle_params is None: + if not self._inited: # this is our first loop: load up the throttle params + assert self.pusher_id is not None self.throttle_params = await self.store.get_throttle_params_by_room( self.pusher_id ) + self._inited = True # if the max ordering changes while we're running _unsafe_process, # call it again, and so on until we've caught up. @@ -157,17 +150,18 @@ class EmailPusher: finally: self._is_processing = False - async def _unsafe_process(self): + async def _unsafe_process(self) -> None: """ Main logic of the push loop without the wrapper function that sets up logging, measures and guards against multiple instances of it being run. """ start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering - fn = self.store.get_unread_push_actions_for_user_in_range_for_email - unprocessed = await fn(self.user_id, start, self.max_stream_ordering) + unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email( + self.user_id, start, self.max_stream_ordering + ) - soonest_due_at = None + soonest_due_at = None # type: Optional[int] if not unprocessed: await self.save_last_stream_ordering_and_success(self.max_stream_ordering) @@ -224,11 +218,9 @@ class EmailPusher: self.seconds_until(soonest_due_at), self.on_timer ) - async def save_last_stream_ordering_and_success(self, last_stream_ordering): - if last_stream_ordering is None: - # This happens if we haven't yet processed anything - return - + async def save_last_stream_ordering_and_success( + self, last_stream_ordering: int + ) -> None: self.last_stream_ordering = last_stream_ordering pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success( self.app_id, @@ -242,28 +234,30 @@ class EmailPusher: # lets just stop and return. 
self.on_stop() - def seconds_until(self, ts_msec): + def seconds_until(self, ts_msec: int) -> float: secs = (ts_msec - self.clock.time_msec()) / 1000 return max(secs, 0) - def get_room_throttle_ms(self, room_id): + def get_room_throttle_ms(self, room_id: str) -> int: if room_id in self.throttle_params: - return self.throttle_params[room_id]["throttle_ms"] + return self.throttle_params[room_id].throttle_ms else: return 0 - def get_room_last_sent_ts(self, room_id): + def get_room_last_sent_ts(self, room_id: str) -> int: if room_id in self.throttle_params: - return self.throttle_params[room_id]["last_sent_ts"] + return self.throttle_params[room_id].last_sent_ts else: return 0 - def room_ready_to_notify_at(self, room_id): + def room_ready_to_notify_at(self, room_id: str) -> int: """ Determines whether throttling should prevent us from sending an email for the given room - Returns: The timestamp when we are next allowed to send an email notif - for this room + + Returns: + The timestamp when we are next allowed to send an email notif + for this room """ last_sent_ts = self.get_room_last_sent_ts(room_id) throttle_ms = self.get_room_throttle_ms(room_id) @@ -271,7 +265,9 @@ class EmailPusher: may_send_at = last_sent_ts + throttle_ms return may_send_at - async def sent_notif_update_throttle(self, room_id, notified_push_action): + async def sent_notif_update_throttle( + self, room_id: str, notified_push_action: dict + ) -> None: # We have sent a notification, so update the throttle accordingly. # If the event that triggered the notif happened more than # THROTTLE_RESET_AFTER_MS after the previous one that triggered a @@ -301,15 +297,15 @@ class EmailPusher: new_throttle_ms = min( current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS ) - self.throttle_params[room_id] = { - "last_sent_ts": self.clock.time_msec(), - "throttle_ms": new_throttle_ms, - } + self.throttle_params[room_id] = ThrottleParams( + self.clock.time_msec(), new_throttle_ms, + ) + assert self.pusher_id is not None await self.store.set_throttle_params( self.pusher_id, room_id, self.throttle_params[room_id] ) - async def send_notification(self, push_actions, reason): + async def send_notification(self, push_actions: List[dict], reason: dict) -> None: logger.info("Sending notif email for user %r", self.user_id) await self.mailer.send_notification_mail( diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 26706bf3e1..e048b0d59e 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py
@@ -14,18 +14,24 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import urllib.parse +from typing import TYPE_CHECKING, Any, Dict, Iterable, Union from prometheus_client import Counter from twisted.internet.error import AlreadyCalled, AlreadyCancelled from synapse.api.constants import EventTypes +from synapse.events import EventBase from synapse.logging import opentracing from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.push import PusherConfigException +from synapse.push import Pusher, PusherConfig, PusherConfigException from . import push_rule_evaluator, push_tools +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) http_push_processed_counter = Counter( @@ -49,94 +55,89 @@ http_badges_failed_counter = Counter( ) -class HttpPusher: +class HttpPusher(Pusher): INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes MAX_BACKOFF_SEC = 60 * 60 # This one's in ms because we compare it against the clock GIVE_UP_AFTER_MS = 24 * 60 * 60 * 1000 - def __init__(self, hs, pusherdict): - self.hs = hs - self.store = self.hs.get_datastore() + def __init__(self, hs: "HomeServer", pusher_config: PusherConfig): + super().__init__(hs, pusher_config) self.storage = self.hs.get_storage() - self.clock = self.hs.get_clock() - self.state_handler = self.hs.get_state_handler() - self.user_id = pusherdict["user_name"] - self.app_id = pusherdict["app_id"] - self.app_display_name = pusherdict["app_display_name"] - self.device_display_name = pusherdict["device_display_name"] - self.pushkey = pusherdict["pushkey"] - self.pushkey_ts = pusherdict["ts"] - self.data = pusherdict["data"] - self.last_stream_ordering = pusherdict["last_stream_ordering"] + self.app_display_name = pusher_config.app_display_name + self.device_display_name = pusher_config.device_display_name + self.pushkey_ts = pusher_config.ts + self.data = pusher_config.data self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC - self.failing_since = pusherdict["failing_since"] + self.failing_since = pusher_config.failing_since self.timed_call = None self._is_processing = False + self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room - # This is the highest stream ordering we know it's safe to process. - # When new events arrive, we'll be given a window of new events: we - # should honour this rather than just looking for anything higher - # because of potential out-of-order event serialisation. This starts - # off as None though as we don't know any better. - self.max_stream_ordering = None - - if "data" not in pusherdict: - raise PusherConfigException("No 'data' key for HTTP pusher") - self.data = pusherdict["data"] + self.data = pusher_config.data + if self.data is None: + raise PusherConfigException("'data' key can not be null for HTTP pusher") self.name = "%s/%s/%s" % ( - pusherdict["user_name"], - pusherdict["app_id"], - pusherdict["pushkey"], + pusher_config.user_name, + pusher_config.app_id, + pusher_config.pushkey, ) - if self.data is None: - raise PusherConfigException("data can not be null for HTTP pusher") - + # Validate that there's a URL and it is of the proper form. 
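Only the path of the configured push gateway URL is checked here; the scheme and host are left alone (the spec says HTTPS, but enforcing that is not the homeserver's job). A quick illustration of what passes the new check, using urllib.parse directly; the gateway hostnames are invented for the example:

from urllib.parse import urlparse

NOTIFY_PATH = "/_matrix/push/v1/notify"

# Accepted: the path component is exactly the v1 notify endpoint.
assert urlparse("https://push.example.com/_matrix/push/v1/notify").path == NOTIFY_PATH

# Rejected: a bare gateway base URL would now raise PusherConfigException
# in HttpPusher.__init__ instead of being accepted silently.
assert urlparse("https://push.example.com/").path != NOTIFY_PATH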
if "url" not in self.data: raise PusherConfigException("'url' required in data for HTTP pusher") - self.url = self.data["url"] - self.http_client = hs.get_proxied_http_client() + + url = self.data["url"] + if not isinstance(url, str): + raise PusherConfigException("'url' must be a string") + url_parts = urllib.parse.urlparse(url) + # Note that the specification also says the scheme must be HTTPS, but + # it isn't up to the homeserver to verify that. + if url_parts.path != "/_matrix/push/v1/notify": + raise PusherConfigException( + "'url' must have a path of '/_matrix/push/v1/notify'" + ) + + self.url = url + self.http_client = hs.get_proxied_blacklisted_http_client() self.data_minus_url = {} self.data_minus_url.update(self.data) del self.data_minus_url["url"] - def on_started(self, should_check_for_notifs): + def on_started(self, should_check_for_notifs: bool) -> None: """Called when this pusher has been started. Args: - should_check_for_notifs (bool): Whether we should immediately + should_check_for_notifs: Whether we should immediately check for push to send. Set to False only if it's known there is nothing to send """ if should_check_for_notifs: self._start_processing() - def on_new_notifications(self, max_stream_ordering): - self.max_stream_ordering = max( - max_stream_ordering, self.max_stream_ordering or 0 - ) - self._start_processing() - - def on_new_receipts(self, min_stream_id, max_stream_id): + def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None: # Note that the min here shouldn't be relied upon to be accurate. # We could check the receipts are actually m.read receipts here, # but currently that's the only type of receipt anyway... run_as_background_process("http_pusher.on_new_receipts", self._update_badge) - async def _update_badge(self): + async def _update_badge(self) -> None: # XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems # to be largely redundant. perhaps we can remove it. - badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) + badge = await push_tools.get_badge_count( + self.hs.get_datastore(), + self.user_id, + group_by_room=self._group_unread_count_by_room, + ) await self._send_badge(badge) - def on_timer(self): + def on_timer(self) -> None: self._start_processing() - def on_stop(self): + def on_stop(self) -> None: if self.timed_call: try: self.timed_call.cancel() @@ -144,13 +145,13 @@ class HttpPusher: pass self.timed_call = None - def _start_processing(self): + def _start_processing(self) -> None: if self._is_processing: return run_as_background_process("httppush.process", self._process) - async def _process(self): + async def _process(self) -> None: # we should never get here if we are already processing assert not self._is_processing @@ -169,15 +170,13 @@ class HttpPusher: finally: self._is_processing = False - async def _unsafe_process(self): + async def _unsafe_process(self) -> None: """ Looks for unset notifications and dispatch them, in order Never call this directly: use _process which will only allow this to run once per pusher. 
""" - - fn = self.store.get_unread_push_actions_for_user_in_range_for_http - unprocessed = await fn( + unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_http( self.user_id, self.last_stream_ordering, self.max_stream_ordering ) @@ -246,17 +245,12 @@ class HttpPusher: ) self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.last_stream_ordering = push_action["stream_ordering"] - pusher_still_exists = await self.store.update_pusher_last_stream_ordering( + await self.store.update_pusher_last_stream_ordering( self.app_id, self.pushkey, self.user_id, self.last_stream_ordering, ) - if not pusher_still_exists: - # The pusher has been deleted while we were processing, so - # lets just stop and return. - self.on_stop() - return self.failing_since = None await self.store.update_pusher_failing_since( @@ -272,12 +266,16 @@ class HttpPusher: ) break - async def _process_one(self, push_action): + async def _process_one(self, push_action: dict) -> bool: if "notify" not in push_action["actions"]: return True tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"]) - badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) + badge = await push_tools.get_badge_count( + self.hs.get_datastore(), + self.user_id, + group_by_room=self._group_unread_count_by_room, + ) event = await self.store.get_event(push_action["event_id"], allow_none=True) if event is None: @@ -299,7 +297,9 @@ class HttpPusher: await self.hs.remove_pusher(self.app_id, pk, self.user_id) return True - async def _build_notification_dict(self, event, tweaks, badge): + async def _build_notification_dict( + self, event: EventBase, tweaks: Dict[str, bool], badge: int + ) -> Dict[str, Any]: priority = "low" if ( event.type == EventTypes.Encrypted @@ -310,6 +310,8 @@ class HttpPusher: # or may do so (i.e. is encrypted so has unknown effects). priority = "high" + # This was checked in the __init__, but mypy doesn't seem to know that. + assert self.data is not None if self.data.get("format") == "event_id_only": d = { "notification": { @@ -329,9 +331,7 @@ class HttpPusher: } return d - ctx = await push_tools.get_context_for_event( - self.storage, self.state_handler, event, self.user_id - ) + ctx = await push_tools.get_context_for_event(self.storage, event, self.user_id) d = { "notification": { @@ -371,7 +371,9 @@ class HttpPusher: return d - async def dispatch_push(self, event, tweaks, badge): + async def dispatch_push( + self, event: EventBase, tweaks: Dict[str, bool], badge: int + ) -> Union[bool, Iterable[str]]: notification_dict = await self._build_notification_dict(event, tweaks, badge) if not notification_dict: return [] diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 455a1acb46..4d875dcb91 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py
@@ -19,24 +19,28 @@ import logging import urllib.parse from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText -from typing import Iterable, List, TypeVar +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar import bleach import jinja2 -from synapse.api.constants import EventTypes +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import StoreError from synapse.config.emailconfig import EmailSubjectConfig +from synapse.events import EventBase from synapse.logging.context import make_deferred_yieldable from synapse.push.presentable_names import ( calculate_room_name, descriptor_from_member_events, name_from_member_event, ) -from synapse.types import UserID +from synapse.types import StateMap, UserID from synapse.util.async_helpers import concurrently_execute from synapse.visibility import filter_events_for_client +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) T = TypeVar("T") @@ -93,7 +97,13 @@ ALLOWED_ATTRS = { class Mailer: - def __init__(self, hs, app_name, template_html, template_text): + def __init__( + self, + hs: "HomeServer", + app_name: str, + template_html: jinja2.Template, + template_text: jinja2.Template, + ): self.hs = hs self.template_html = template_html self.template_text = template_text @@ -108,17 +118,19 @@ class Mailer: logger.info("Created Mailer for app_name %s" % app_name) - async def send_password_reset_mail(self, email_address, token, client_secret, sid): + async def send_password_reset_mail( + self, email_address: str, token: str, client_secret: str, sid: str + ) -> None: """Send an email with a password reset link to a user Args: - email_address (str): Email address we're sending the password + email_address: Email address we're sending the password reset to - token (str): Unique token generated by the server to verify + token: Unique token generated by the server to verify the email was received - client_secret (str): Unique token generated by the client to + client_secret: Unique token generated by the client to group together multiple email sending attempts - sid (str): The generated session ID + sid: The generated session ID """ params = {"token": token, "client_secret": client_secret, "sid": sid} link = ( @@ -136,17 +148,19 @@ class Mailer: template_vars, ) - async def send_registration_mail(self, email_address, token, client_secret, sid): + async def send_registration_mail( + self, email_address: str, token: str, client_secret: str, sid: str + ) -> None: """Send an email with a registration confirmation link to a user Args: - email_address (str): Email address we're sending the registration + email_address: Email address we're sending the registration link to - token (str): Unique token generated by the server to verify + token: Unique token generated by the server to verify the email was received - client_secret (str): Unique token generated by the client to + client_secret: Unique token generated by the client to group together multiple email sending attempts - sid (str): The generated session ID + sid: The generated session ID """ params = {"token": token, "client_secret": client_secret, "sid": sid} link = ( @@ -164,18 +178,20 @@ class Mailer: template_vars, ) - async def send_add_threepid_mail(self, email_address, token, client_secret, sid): + async def send_add_threepid_mail( + self, email_address: str, token: str, client_secret: str, sid: str + ) -> None: """Send an email with a validation link to a user for 
adding a 3pid to their account Args: - email_address (str): Email address we're sending the validation link to + email_address: Email address we're sending the validation link to - token (str): Unique token generated by the server to verify the email was received + token: Unique token generated by the server to verify the email was received - client_secret (str): Unique token generated by the client to group together + client_secret: Unique token generated by the client to group together multiple email sending attempts - sid (str): The generated session ID + sid: The generated session ID """ params = {"token": token, "client_secret": client_secret, "sid": sid} link = ( @@ -194,8 +210,13 @@ class Mailer: ) async def send_notification_mail( - self, app_id, user_id, email_address, push_actions, reason - ): + self, + app_id: str, + user_id: str, + email_address: str, + push_actions: Iterable[Dict[str, Any]], + reason: Dict[str, Any], + ) -> None: """Send email regarding a user's room notifications""" rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions]) @@ -203,7 +224,7 @@ class Mailer: [pa["event_id"] for pa in push_actions] ) - notifs_by_room = {} + notifs_by_room = {} # type: Dict[str, List[Dict[str, Any]]] for pa in push_actions: notifs_by_room.setdefault(pa["room_id"], []).append(pa) @@ -262,7 +283,9 @@ class Mailer: await self.send_email(email_address, summary_text, template_vars) - async def send_email(self, email_address, subject, extra_template_vars): + async def send_email( + self, email_address: str, subject: str, extra_template_vars: Dict[str, Any] + ) -> None: """Send an email with the given information and template text""" try: from_string = self.hs.config.email_notif_from % {"app": self.app_name} @@ -315,11 +338,21 @@ class Mailer: ) async def get_room_vars( - self, room_id, user_id, notifs, notif_events, room_state_ids - ): - my_member_event_id = room_state_ids[("m.room.member", user_id)] - my_member_event = await self.store.get_event(my_member_event_id) - is_invite = my_member_event.content["membership"] == "invite" + self, + room_id: str, + user_id: str, + notifs: Iterable[Dict[str, Any]], + notif_events: Dict[str, EventBase], + room_state_ids: StateMap[str], + ) -> Dict[str, Any]: + # Check if one of the notifs is an invite event for the user. 
+ is_invite = False + for n in notifs: + ev = notif_events[n["event_id"]] + if ev.type == EventTypes.Member and ev.state_key == user_id: + if ev.content.get("membership") == Membership.INVITE: + is_invite = True + break room_name = await calculate_room_name(self.store, room_state_ids, user_id) @@ -329,7 +362,7 @@ class Mailer: "notifs": [], "invite": is_invite, "link": self.make_room_link(room_id), - } + } # type: Dict[str, Any] if not is_invite: for n in notifs: @@ -360,7 +393,13 @@ class Mailer: return room_vars - async def get_notif_vars(self, notif, user_id, notif_event, room_state_ids): + async def get_notif_vars( + self, + notif: Dict[str, Any], + user_id: str, + notif_event: EventBase, + room_state_ids: StateMap[str], + ) -> Dict[str, Any]: results = await self.store.get_events_around( notif["room_id"], notif["event_id"], @@ -386,9 +425,11 @@ class Mailer: return ret - async def get_message_vars(self, notif, event, room_state_ids): - if event.type != EventTypes.Message: - return + async def get_message_vars( + self, notif: Dict[str, Any], event: EventBase, room_state_ids: StateMap[str] + ) -> Optional[Dict[str, Any]]: + if event.type != EventTypes.Message and event.type != EventTypes.Encrypted: + return None sender_state_event_id = room_state_ids[("m.room.member", event.sender)] sender_state_event = await self.store.get_event(sender_state_event_id) @@ -399,10 +440,8 @@ class Mailer: # sender_hash % the number of default images to choose from sender_hash = string_ordinal_total(event.sender) - msgtype = event.content.get("msgtype") - ret = { - "msgtype": msgtype, + "event_type": event.type, "is_historical": event.event_id != notif["event_id"], "id": event.event_id, "ts": event.origin_server_ts, @@ -411,6 +450,14 @@ class Mailer: "sender_hash": sender_hash, } + # Encrypted messages don't have any additional useful information. + if event.type == EventTypes.Encrypted: + return ret + + msgtype = event.content.get("msgtype") + + ret["msgtype"] = msgtype + if msgtype == "m.text": self.add_text_message_vars(ret, event) elif msgtype == "m.image": @@ -421,7 +468,9 @@ class Mailer: return ret - def add_text_message_vars(self, messagevars, event): + def add_text_message_vars( + self, messagevars: Dict[str, Any], event: EventBase + ) -> None: msgformat = event.content.get("format") messagevars["format"] = msgformat @@ -434,15 +483,22 @@ class Mailer: elif body: messagevars["body_text_html"] = safe_text(body) - return messagevars - - def add_image_message_vars(self, messagevars, event): - messagevars["image_url"] = event.content["url"] - - return messagevars + def add_image_message_vars( + self, messagevars: Dict[str, Any], event: EventBase + ) -> None: + """ + Potentially add an image URL to the message variables. 
+ """ + if "url" in event.content: + messagevars["image_url"] = event.content["url"] async def make_summary_text( - self, notifs_by_room, room_state_ids, notif_events, user_id, reason + self, + notifs_by_room: Dict[str, List[Dict[str, Any]]], + room_state_ids: Dict[str, StateMap[str]], + notif_events: Dict[str, EventBase], + user_id: str, + reason: Dict[str, Any], ): if len(notifs_by_room) == 1: # Only one room has new stuff @@ -455,16 +511,26 @@ class Mailer: self.store, room_state_ids[room_id], user_id, fallback_to_members=False ) - my_member_event_id = room_state_ids[room_id][("m.room.member", user_id)] - my_member_event = await self.store.get_event(my_member_event_id) - if my_member_event.content["membership"] == "invite": - inviter_member_event_id = room_state_ids[room_id][ - ("m.room.member", my_member_event.sender) - ] - inviter_member_event = await self.store.get_event( - inviter_member_event_id + # See if one of the notifs is an invite event for the user + invite_event = None + for n in notifs_by_room[room_id]: + ev = notif_events[n["event_id"]] + if ev.type == EventTypes.Member and ev.state_key == user_id: + if ev.content.get("membership") == Membership.INVITE: + invite_event = ev + break + + if invite_event: + inviter_member_event_id = room_state_ids[room_id].get( + ("m.room.member", invite_event.sender) ) - inviter_name = name_from_member_event(inviter_member_event) + inviter_name = invite_event.sender + if inviter_member_event_id: + inviter_member_event = await self.store.get_event( + inviter_member_event_id, allow_none=True + ) + if inviter_member_event: + inviter_name = name_from_member_event(inviter_member_event) if room_name is None: return self.email_subjects.invite_from_person % { @@ -559,7 +625,7 @@ class Mailer: "app": self.app_name, } - def make_room_link(self, room_id): + def make_room_link(self, room_id: str) -> str: if self.hs.config.email_riot_base_url: base_url = "%s/#/room" % (self.hs.config.email_riot_base_url) elif self.app_name == "Vector": @@ -569,7 +635,7 @@ class Mailer: base_url = "https://matrix.to/#" return "%s/%s" % (base_url, room_id) - def make_notif_link(self, notif): + def make_notif_link(self, notif: Dict[str, str]) -> str: if self.hs.config.email_riot_base_url: return "%s/#/room/%s/%s" % ( self.hs.config.email_riot_base_url, @@ -585,7 +651,9 @@ class Mailer: else: return "https://matrix.to/#/%s/%s" % (notif["room_id"], notif["event_id"]) - def make_unsubscribe_link(self, user_id, app_id, email_address): + def make_unsubscribe_link( + self, user_id: str, app_id: str, email_address: str + ) -> str: params = { "access_token": self.macaroon_gen.generate_delete_pusher_token(user_id), "app_id": app_id, @@ -599,7 +667,7 @@ class Mailer: ) -def safe_markup(raw_html): +def safe_markup(raw_html: str) -> jinja2.Markup: return jinja2.Markup( bleach.linkify( bleach.clean( @@ -614,7 +682,7 @@ def safe_markup(raw_html): ) -def safe_text(raw_text): +def safe_text(raw_text: str) -> jinja2.Markup: """ Process text: treat it as HTML but escape any tags (ie. just escape the HTML) then linkify it. @@ -634,7 +702,7 @@ def deduped_ordered_list(it: Iterable[T]) -> List[T]: return ret -def string_ordinal_total(s): +def string_ordinal_total(s: str) -> int: tot = 0 for c in s: tot += ord(c) diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py
index d8f4a453cd..7e50341d74 100644 --- a/synapse/push/presentable_names.py +++ b/synapse/push/presentable_names.py
@@ -15,8 +15,14 @@ import logging import re +from typing import TYPE_CHECKING, Dict, Iterable, Optional from synapse.api.constants import EventTypes +from synapse.events import EventBase +from synapse.types import StateMap + +if TYPE_CHECKING: + from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) @@ -28,25 +34,29 @@ ALL_ALONE = "Empty Room" async def calculate_room_name( - store, - room_state_ids, - user_id, - fallback_to_members=True, - fallback_to_single_member=True, -): + store: "DataStore", + room_state_ids: StateMap[str], + user_id: str, + fallback_to_members: bool = True, + fallback_to_single_member: bool = True, +) -> Optional[str]: """ Works out a user-facing name for the given room as per Matrix spec recommendations. Does not yet support internationalisation. Args: - room_state: Dictionary of the room's state + store: The data store to query. + room_state_ids: Dictionary of the room's state IDs. user_id: The ID of the user to whom the room name is being presented fallback_to_members: If False, return None instead of generating a name based on the room's members if the room has no title or aliases. + fallback_to_single_member: If False, return None instead of generating a + name based on the user who invited this user to the room if the room + has no title or aliases. Returns: - (string or None) A human readable name for the room. + A human readable name for the room, if possible. """ # does it have a name? if (EventTypes.Name, "") in room_state_ids: @@ -97,7 +107,7 @@ async def calculate_room_name( name_from_member_event(inviter_member_event), ) else: - return + return None else: return "Room Invite" @@ -150,19 +160,19 @@ async def calculate_room_name( else: return ALL_ALONE elif len(other_members) == 1 and not fallback_to_single_member: - return - else: - return descriptor_from_member_events(other_members) + return None + + return descriptor_from_member_events(other_members) -def descriptor_from_member_events(member_events): +def descriptor_from_member_events(member_events: Iterable[EventBase]) -> str: """Get a description of the room based on the member events. Args: - member_events (Iterable[FrozenEvent]) + member_events: The events of a room. Returns: - str + The room description """ member_events = list(member_events) @@ -183,7 +193,7 @@ def descriptor_from_member_events(member_events): ) -def name_from_member_event(member_event): +def name_from_member_event(member_event: EventBase) -> str: if ( member_event.content and "displayname" in member_event.content @@ -193,12 +203,12 @@ def name_from_member_event(member_event): return member_event.state_key -def _state_as_two_level_dict(state): - ret = {} +def _state_as_two_level_dict(state: StateMap[str]) -> Dict[str, Dict[str, str]]: + ret = {} # type: Dict[str, Dict[str, str]] for k, v in state.items(): ret.setdefault(k[0], {})[k[1]] = v return ret -def _looks_like_an_alias(string): +def _looks_like_an_alias(string: str) -> bool: return ALIAS_RE.match(string) is not None diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 709ace01e5..ba1877adcd 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py
@@ -16,11 +16,10 @@ import logging import re -from typing import Any, Dict, List, Pattern, Union +from typing import Any, Dict, List, Optional, Pattern, Tuple, Union from synapse.events import EventBase from synapse.types import UserID -from synapse.util.caches import register_cache from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -31,22 +30,30 @@ IS_GLOB = re.compile(r"[\?\*\[\]]") INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") -def _room_member_count(ev, condition, room_member_count): +def _room_member_count( + ev: EventBase, condition: Dict[str, Any], room_member_count: int +) -> bool: return _test_ineq_condition(condition, room_member_count) -def _sender_notification_permission(ev, condition, sender_power_level, power_levels): +def _sender_notification_permission( + ev: EventBase, + condition: Dict[str, Any], + sender_power_level: int, + power_levels: Dict[str, Union[int, Dict[str, int]]], +) -> bool: notif_level_key = condition.get("key") if notif_level_key is None: return False notif_levels = power_levels.get("notifications", {}) + assert isinstance(notif_levels, dict) room_notif_level = notif_levels.get(notif_level_key, 50) return sender_power_level >= room_notif_level -def _test_ineq_condition(condition, number): +def _test_ineq_condition(condition: Dict[str, Any], number: int) -> bool: if "is" not in condition: return False m = INEQUALITY_EXPR.match(condition["is"]) @@ -111,7 +118,7 @@ class PushRuleEvaluatorForEvent: event: EventBase, room_member_count: int, sender_power_level: int, - power_levels: dict, + power_levels: Dict[str, Union[int, Dict[str, int]]], ): self._event = event self._room_member_count = room_member_count @@ -121,7 +128,9 @@ class PushRuleEvaluatorForEvent: # Maps strings of e.g. 'content.body' -> event["content"]["body"] self._value_cache = _flatten_dict(event) - def matches(self, condition: dict, user_id: str, display_name: str) -> bool: + def matches( + self, condition: Dict[str, Any], user_id: str, display_name: str + ) -> bool: if condition["kind"] == "event_match": return self._event_match(condition, user_id) elif condition["kind"] == "contains_display_name": @@ -174,20 +183,21 @@ class PushRuleEvaluatorForEvent: # Similar to _glob_matches, but do not treat display_name as a glob. r = regex_cache.get((display_name, False, True), None) if not r: - r = re.escape(display_name) - r = _re_word_boundary(r) - r = re.compile(r, flags=re.IGNORECASE) + r1 = re.escape(display_name) + r1 = _re_word_boundary(r1) + r = re.compile(r1, flags=re.IGNORECASE) regex_cache[(display_name, False, True)] = r - return r.search(body) + return bool(r.search(body)) - def _get_value(self, dotted_key: str) -> str: + def _get_value(self, dotted_key: str) -> Optional[str]: return self._value_cache.get(dotted_key, None) # Caches (string, is_glob, word_boundary) -> regex for push. 
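The display-name check above escapes the name, wraps it in the word-boundary pattern produced by _re_word_boundary, compiles it case-insensitively, and memoises the compiled pattern in regex_cache keyed by (string, is_glob, word_boundary). A small standalone illustration of that matching behaviour (the helper name below is made up):

import re
from typing import Pattern

def word_boundary_pattern(literal: str) -> Pattern:
    # Same shape as _re_word_boundary: the escaped literal must be bounded by
    # a non-word character or the start/end of the body on both sides.
    return re.compile(r"(^|\W)%s(\W|$)" % (re.escape(literal),), flags=re.IGNORECASE)

pattern = word_boundary_pattern("alice")
assert bool(pattern.search("hey Alice, are you there?"))  # case-insensitive, punctuation ok
assert not bool(pattern.search("malice aforethought"))  # not matched inside another word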
See _glob_matches -regex_cache = LruCache(50000) -register_cache("cache", "regex_push_cache", regex_cache) +regex_cache = LruCache( + 50000, "regex_push_cache" +) # type: LruCache[Tuple[str, bool, bool], Pattern] def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: @@ -205,7 +215,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: if not r: r = _glob_to_re(glob, word_boundary) regex_cache[(glob, True, word_boundary)] = r - return r.search(value) + return bool(r.search(value)) except re.error: logger.warning("Failed to parse glob to regex: %r", glob) return False @@ -261,7 +271,13 @@ def _re_word_boundary(r: str) -> str: return r"(^|\W)%s(\W|$)" % (r,) -def _flatten_dict(d, prefix=[], result=None): +def _flatten_dict( + d: Union[EventBase, dict], + prefix: Optional[List[str]] = None, + result: Optional[Dict[str, str]] = None, +) -> Dict[str, str]: + if prefix is None: + prefix = [] if result is None: result = {} for key, value in d.items(): diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index d0145666bf..df34103224 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py
@@ -12,12 +12,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict +from synapse.events import EventBase from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.storage import Storage +from synapse.storage.databases.main import DataStore -async def get_badge_count(store, user_id): +async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int: invites = await store.get_invited_rooms_for_local_user(user_id) joins = await store.get_rooms_for_user(user_id) @@ -34,13 +37,21 @@ async def get_badge_count(store, user_id): room_id, user_id, last_unread_event_id ) ) - # return one badge count per conversation, as count per - # message is so noisy as to be almost useless - badge += 1 if notifs["notify_count"] else 0 + if notifs["notify_count"] == 0: + continue + + if group_by_room: + # return one badge count per conversation + badge += 1 + else: + # increment the badge count by the number of unread messages in the room + badge += notifs["notify_count"] return badge -async def get_context_for_event(storage: Storage, state_handler, ev, user_id): +async def get_context_for_event( + storage: Storage, ev: EventBase, user_id: str +) -> Dict[str, str]: ctx = {} room_state_ids = await storage.state.get_state_ids_for_event(ev.event_id) diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index 2a52e226e3..2aa7918fb4 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py
@@ -14,25 +14,31 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Callable, Dict, Optional +from synapse.push import Pusher, PusherConfig from synapse.push.emailpusher import EmailPusher +from synapse.push.httppusher import HttpPusher from synapse.push.mailer import Mailer -from .httppusher import HttpPusher +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) class PusherFactory: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.config = hs.config - self.pusher_types = {"http": HttpPusher} + self.pusher_types = { + "http": HttpPusher + } # type: Dict[str, Callable[[HomeServer, PusherConfig], Pusher]] logger.info("email enable notifs: %r", hs.config.email_enable_notifs) if hs.config.email_enable_notifs: - self.mailers = {} # app_name -> Mailer + self.mailers = {} # type: Dict[str, Mailer] self._notif_template_html = hs.config.email_notif_template_html self._notif_template_text = hs.config.email_notif_template_text @@ -41,16 +47,18 @@ class PusherFactory: logger.info("defined email pusher type") - def create_pusher(self, pusherdict): - kind = pusherdict["kind"] + def create_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]: + kind = pusher_config.kind f = self.pusher_types.get(kind, None) if not f: return None - logger.debug("creating %s pusher for %r", kind, pusherdict) - return f(self.hs, pusherdict) + logger.debug("creating %s pusher for %r", kind, pusher_config) + return f(self.hs, pusher_config) - def _create_email_pusher(self, _hs, pusherdict): - app_name = self._app_name_from_pusherdict(pusherdict) + def _create_email_pusher( + self, _hs: "HomeServer", pusher_config: PusherConfig + ) -> EmailPusher: + app_name = self._app_name_from_pusherdict(pusher_config) mailer = self.mailers.get(app_name) if not mailer: mailer = Mailer( @@ -60,10 +68,10 @@ class PusherFactory: template_text=self._notif_template_text, ) self.mailers[app_name] = mailer - return EmailPusher(self.hs, pusherdict, mailer) + return EmailPusher(self.hs, pusher_config, mailer) - def _app_name_from_pusherdict(self, pusherdict): - data = pusherdict["data"] + def _app_name_from_pusherdict(self, pusher_config: PusherConfig) -> str: + data = pusher_config.data if isinstance(data, dict): brand = data.get("brand") diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 76150e117b..eed16dbfb5 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py
@@ -15,15 +15,17 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, Union +from typing import TYPE_CHECKING, Dict, Iterable, Optional from prometheus_client import Gauge -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.push import PusherConfigException -from synapse.push.emailpusher import EmailPusher -from synapse.push.httppusher import HttpPusher +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) +from synapse.push import Pusher, PusherConfig, PusherConfigException from synapse.push.pusher import PusherFactory +from synapse.types import JsonDict, RoomStreamToken from synapse.util.async_helpers import concurrently_execute if TYPE_CHECKING: @@ -73,9 +75,9 @@ class PusherPool: self._last_room_stream_id_seen = self.store.get_room_max_stream_ordering() # map from user id to app_id:pushkey to pusher - self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]] + self.pushers = {} # type: Dict[str, Dict[str, Pusher]] - def start(self): + def start(self) -> None: """Starts the pushers off in a background process. """ if not self._should_start_pushers: @@ -85,52 +87,53 @@ class PusherPool: async def add_pusher( self, - user_id, - access_token, - kind, - app_id, - app_display_name, - device_display_name, - pushkey, - lang, - data, - profile_tag="", - ): + user_id: str, + access_token: Optional[int], + kind: str, + app_id: str, + app_display_name: str, + device_display_name: str, + pushkey: str, + lang: Optional[str], + data: JsonDict, + profile_tag: str = "", + ) -> Optional[Pusher]: """Creates a new pusher and adds it to the pool Returns: - EmailPusher|HttpPusher + The newly created pusher. """ time_now_msec = self.clock.time_msec() + # create the pusher setting last_stream_ordering to the current maximum + # stream ordering, so it will process pushes from this point onwards. + last_stream_ordering = self.store.get_room_max_stream_ordering() + # we try to create the pusher just to validate the config: it # will then get pulled out of the database, # recreated, added and started: this means we have only one # code path adding pushers. self.pusher_factory.create_pusher( - { - "id": None, - "user_name": user_id, - "kind": kind, - "app_id": app_id, - "app_display_name": app_display_name, - "device_display_name": device_display_name, - "pushkey": pushkey, - "ts": time_now_msec, - "lang": lang, - "data": data, - "last_stream_ordering": None, - "last_success": None, - "failing_since": None, - } + PusherConfig( + id=None, + user_name=user_id, + access_token=access_token, + profile_tag=profile_tag, + kind=kind, + app_id=app_id, + app_display_name=app_display_name, + device_display_name=device_display_name, + pushkey=pushkey, + ts=time_now_msec, + lang=lang, + data=data, + last_stream_ordering=last_stream_ordering, + last_success=None, + failing_since=None, + ) ) - # create the pusher setting last_stream_ordering to the current maximum - # stream ordering in event_push_actions, so it will process - # pushes from this point onwards. 
- last_stream_ordering = await self.store.get_latest_push_action_stream_ordering() - await self.store.add_pusher( user_id=user_id, access_token=access_token, @@ -150,51 +153,68 @@ class PusherPool: return pusher async def remove_pushers_by_app_id_and_pushkey_not_user( - self, app_id, pushkey, not_user_id - ): + self, app_id: str, pushkey: str, not_user_id: str + ) -> None: to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) for p in to_remove: - if p["user_name"] != not_user_id: + if p.user_name != not_user_id: logger.info( "Removing pusher for app id %s, pushkey %s, user %s", app_id, pushkey, - p["user_name"], + p.user_name, ) - await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) + await self.remove_pusher(p.app_id, p.pushkey, p.user_name) - async def remove_pushers_by_access_token(self, user_id, access_tokens): + async def remove_pushers_by_access_token( + self, user_id: str, access_tokens: Iterable[int] + ) -> None: """Remove the pushers for a given user corresponding to a set of access_tokens. Args: - user_id (str): user to remove pushers for - access_tokens (Iterable[int]): access token *ids* to remove pushers - for + user_id: user to remove pushers for + access_tokens: access token *ids* to remove pushers for """ if not self._pusher_shard_config.should_handle(self._instance_name, user_id): return tokens = set(access_tokens) for p in await self.store.get_pushers_by_user_id(user_id): - if p["access_token"] in tokens: + if p.access_token in tokens: logger.info( "Removing pusher for app id %s, pushkey %s, user %s", - p["app_id"], - p["pushkey"], - p["user_name"], + p.app_id, + p.pushkey, + p.user_name, ) - await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) + await self.remove_pusher(p.app_id, p.pushkey, p.user_name) - async def on_new_notifications(self, max_stream_id: int): + def on_new_notifications(self, max_token: RoomStreamToken) -> None: if not self.pushers: # nothing to do here. return + # We just use the minimum stream ordering and ignore the vector clock + # component. This is safe to do as long as we *always* ignore the vector + # clock components. + max_stream_id = max_token.stream + if max_stream_id < self._last_room_stream_id_seen: # Nothing to do return + # We only start a new background process if necessary rather than + # optimistically (to cut down on overhead). + self._on_new_notifications(max_token) + + @wrap_as_background_process("on_new_notifications") + async def _on_new_notifications(self, max_token: RoomStreamToken) -> None: + # We just use the minimum stream ordering and ignore the vector clock + # component. This is safe to do as long as we *always* ignore the vector + # clock components. + max_stream_id = max_token.stream + prev_stream_id = self._last_room_stream_id_seen self._last_room_stream_id_seen = max_stream_id @@ -214,12 +234,14 @@ class PusherPool: if u in self.pushers: for p in self.pushers[u].values(): - p.on_new_notifications(max_stream_id) + p.on_new_notifications(max_token) except Exception: logger.exception("Exception in pusher on_new_notifications") - async def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): + async def on_new_receipts( + self, min_stream_id: int, max_stream_id: int, affected_room_ids: Iterable[str] + ) -> None: if not self.pushers: # nothing to do here. 
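on_new_notifications now takes a RoomStreamToken rather than a bare stream id; the pool reads only the integer stream position and deliberately ignores the per-writer vector-clock component, which is safe only if every consumer ignores it consistently. A minimal sketch of that convention, assuming a simplified token type (the real RoomStreamToken carries more than these two fields):

import attr
from typing import Dict

@attr.attrs(slots=True, frozen=True)
class SimpleStreamToken:
    # Stand-in for RoomStreamToken: a global stream position plus a
    # per-writer vector clock that the pusher code never consults.
    stream = attr.ib(type=int)
    instance_map = attr.ib(type=Dict[str, int], factory=dict)

def position_for_pushers(token: SimpleStreamToken) -> int:
    # Pushers only ever compare the integer position; dropping the vector
    # clock here is what "always ignore the vector clock components" means.
    return token.stream

assert position_for_pushers(SimpleStreamToken(42, {"worker1": 40})) == 42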
return @@ -247,28 +269,30 @@ class PusherPool: except Exception: logger.exception("Exception in pusher on_new_receipts") - async def start_pusher_by_id(self, app_id, pushkey, user_id): + async def start_pusher_by_id( + self, app_id: str, pushkey: str, user_id: str + ) -> Optional[Pusher]: """Look up the details for the given pusher, and start it Returns: - EmailPusher|HttpPusher|None: The pusher started, if any + The pusher started, if any """ if not self._should_start_pushers: - return + return None if not self._pusher_shard_config.should_handle(self._instance_name, user_id): - return + return None resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) - pusher_dict = None + pusher_config = None for r in resultlist: - if r["user_name"] == user_id: - pusher_dict = r + if r.user_name == user_id: + pusher_config = r pusher = None - if pusher_dict: - pusher = await self._start_pusher(pusher_dict) + if pusher_config: + pusher = await self._start_pusher(pusher_config) return pusher @@ -283,44 +307,44 @@ class PusherPool: logger.info("Started pushers") - async def _start_pusher(self, pusherdict): + async def _start_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]: """Start the given pusher Args: - pusherdict (dict): dict with the values pulled from the db table + pusher_config: The pusher configuration with the values pulled from the db table Returns: - EmailPusher|HttpPusher + The newly created pusher or None. """ if not self._pusher_shard_config.should_handle( - self._instance_name, pusherdict["user_name"] + self._instance_name, pusher_config.user_name ): - return + return None try: - p = self.pusher_factory.create_pusher(pusherdict) + p = self.pusher_factory.create_pusher(pusher_config) except PusherConfigException as e: logger.warning( "Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s", - pusherdict["id"], - pusherdict.get("user_name"), - pusherdict.get("app_id"), - pusherdict.get("pushkey"), + pusher_config.id, + pusher_config.user_name, + pusher_config.app_id, + pusher_config.pushkey, e, ) - return + return None except Exception: logger.exception( - "Couldn't start pusher id %i: caught Exception", pusherdict["id"], + "Couldn't start pusher id %i: caught Exception", pusher_config.id, ) - return + return None if not p: - return + return None - appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"]) + appid_pushkey = "%s:%s" % (pusher_config.app_id, pusher_config.pushkey) - byuser = self.pushers.setdefault(pusherdict["user_name"], {}) + byuser = self.pushers.setdefault(pusher_config.user_name, {}) if appid_pushkey in byuser: byuser[appid_pushkey].on_stop() byuser[appid_pushkey] = p @@ -330,8 +354,8 @@ class PusherPool: # Check if there *may* be push to process. We do this as this check is a # lot cheaper to do than actually fetching the exact rows we need to # push. - user_id = pusherdict["user_name"] - last_stream_ordering = pusherdict["last_stream_ordering"] + user_id = pusher_config.user_name + last_stream_ordering = pusher_config.last_stream_ordering if last_stream_ordering: have_notifs = await self.store.get_if_maybe_push_in_range_for_user( user_id, last_stream_ordering @@ -345,7 +369,7 @@ class PusherPool: return p - async def remove_pusher(self, app_id, pushkey, user_id): + async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None: appid_pushkey = "%s:%s" % (app_id, pushkey) byuser = self.pushers.get(user_id, {}) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 0ddead8a0f..c97e0df1f5 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py
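The dependency hunk below relaxes the prometheus_client pin and splits the pysaml2 requirement in two, pinning it below 6.4.0 only on Python older than 3.6 via a PEP 508 environment marker. For readers unfamiliar with markers, here is a small, illustrative check using the third-party packaging library (not part of Synapse); only the requirement string is taken from the hunk:

from packaging.markers import Marker
from packaging.requirements import Requirement

# Requirement string from the hunk below; the marker decides whether the pin
# applies to the interpreter that evaluates it.
req = Requirement("pysaml2>=4.5.0,<6.4.0; python_version < '3.6'")
print(req.specifier)          # the version constraints (>=4.5.0 and <6.4.0)
print(req.marker.evaluate())  # True only when run under Python < 3.6
print(Marker("python_version >= '3.6'").evaluate())  # the complementary marker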
@@ -40,6 +40,10 @@ logger = logging.getLogger(__name__) # Note that these both represent runtime dependencies (and the versions # installed are checked at runtime). # +# Also note that we replicate these constraints in the Synapse Dockerfile while +# pre-installing dependencies. If these constraints are updated here, the same +# change should be made in the Dockerfile. +# # [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers. REQUIREMENTS = [ @@ -69,10 +73,7 @@ REQUIREMENTS = [ "msgpack>=0.5.2", "phonenumbers>=8.2.0", # we use GaugeHistogramMetric, which was added in prom-client 0.4.0. - # prom-client has a history of breaking backwards compatibility between - # minor versions (https://github.com/prometheus/client_python/issues/317), - # so we also pin the minor version. - "prometheus_client>=0.4.0,<0.9.0", + "prometheus_client>=0.4.0", # we use attr.validators.deep_iterable, which arrived in 19.1.0 (Note: # Fedora 31 only has 19.1, so if we want to upgrade we should wait until 33 # is out in November.) @@ -95,7 +96,11 @@ CONDITIONAL_REQUIREMENTS = { # python 3.5.2, as per https://github.com/itamarst/eliot/issues/418 'eliot<1.8.0;python_version<"3.5.3"', ], - "saml2": ["pysaml2>=4.5.0"], + "saml2": [ + # pysaml2 6.4.0 is incompatible with Python 3.5 (see https://github.com/IdentityPython/pysaml2/issues/749) + "pysaml2>=4.5.0,<6.4.0;python_version<'3.6'", + "pysaml2>=4.5.0;python_version>='3.6'", + ], "oidc": ["authlib>=0.14.0"], "systemd": ["systemd-python>=231"], "url_preview": ["lxml>=3.5.0"], diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 64edadb624..1492ac922c 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py
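The hunk below adds an optional shared-secret check to replication HTTP requests: the sending side attaches worker_replication_secret as a Bearer token in the Authorization header, and the receiving endpoint rejects any request whose header does not carry the same secret. A minimal, self-contained sketch of that header round trip (helper names are illustrative, not Synapse APIs; the header layout matches the hunk):

from typing import Dict, List, Optional

def build_replication_headers(secret: Optional[str]) -> Dict[bytes, List[bytes]]:
    """Attach an Authorization header if a replication secret is configured."""
    headers = {}  # type: Dict[bytes, List[bytes]]
    if secret:
        headers[b"Authorization"] = [b"Bearer " + secret.encode("ascii")]
    return headers

def check_replication_auth(headers: Dict[bytes, List[bytes]], expected_secret: str) -> None:
    """Raise unless exactly one Authorization header carries the expected secret."""
    auth_headers = headers.get(b"Authorization", [])
    if len(auth_headers) != 1:
        raise RuntimeError("Expected exactly one Authorization header.")
    parts = auth_headers[0].split(b" ")
    if parts[0] == b"Bearer" and len(parts) == 2:
        if parts[1].decode("ascii") == expected_secret:
            return  # success
    raise RuntimeError("Invalid Authorization header.")

hdrs = build_replication_headers("s3cr3t")
check_replication_auth(hdrs, "s3cr3t")  # passes; a mismatched secret raises RuntimeError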
@@ -92,7 +92,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): if self.CACHE: self.response_cache = ResponseCache( hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000 - ) + ) # type: ResponseCache[str] # We reserve `instance_name` as a parameter to sending requests, so we # assert here that sub classes don't try and use the name. @@ -106,6 +106,25 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): assert self.METHOD in ("PUT", "POST", "GET") + self._replication_secret = None + if hs.config.worker.worker_replication_secret: + self._replication_secret = hs.config.worker.worker_replication_secret + + def _check_auth(self, request) -> None: + # Get the authorization header. + auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") + + if len(auth_headers) > 1: + raise RuntimeError("Too many Authorization headers.") + parts = auth_headers[0].split(b" ") + if parts[0] == b"Bearer" and len(parts) == 2: + received_secret = parts[1].decode("ascii") + if self._replication_secret == received_secret: + # Success! + return + + raise RuntimeError("Invalid Authorization header.") + @abc.abstractmethod async def _serialize_payload(**kwargs): """Static method that is called when creating a request. @@ -150,6 +169,12 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): outgoing_gauge = _pending_outgoing_requests.labels(cls.NAME) + replication_secret = None + if hs.config.worker.worker_replication_secret: + replication_secret = hs.config.worker.worker_replication_secret.encode( + "ascii" + ) + @trace(opname="outgoing_replication_request") @outgoing_gauge.track_inprogress() async def send_request(instance_name="master", **kwargs): @@ -202,6 +227,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): # the master, and so whether we should clean up or not. while True: headers = {} # type: Dict[bytes, List[bytes]] + # Add an authorization header, if configured. + if replication_secret: + headers[b"Authorization"] = [b"Bearer " + replication_secret] inject_active_span_byte_dict(headers, None, check_destination=False) try: result = await request_func(uri, data, headers=headers) @@ -236,21 +264,19 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): """ url_args = list(self.PATH_ARGS) - handler = self._handle_request method = self.METHOD if self.CACHE: - handler = self._cached_handler # type: ignore url_args.append("txn_id") args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args) pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args)) http_server.register_paths( - method, [pattern], handler, self.__class__.__name__, + method, [pattern], self._check_auth_and_handle, self.__class__.__name__, ) - def _cached_handler(self, request, txn_id, **kwargs): + def _check_auth_and_handle(self, request, **kwargs): """Called on new incoming requests when caching is enabled. Checks if there is a cached response for the request and returns that, otherwise calls `_handle_request` and caches its response. @@ -258,6 +284,15 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): # We just use the txn_id here, but we probably also want to use the # other PATH_ARGS as well. - assert self.CACHE + # Check the authorization headers before handling the request. 
+ if self._replication_secret: + self._check_auth(request) + + if self.CACHE: + txn_id = kwargs.pop("txn_id") + + return self.response_cache.wrap( + txn_id, self._handle_request, request, **kwargs + ) - return self.response_cache.wrap(txn_id, self._handle_request, request, **kwargs) + return self._handle_request(request, **kwargs) diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 5393b9a9e7..7a0dbb5b1a 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py
@@ -62,7 +62,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): self.store = hs.get_datastore() self.storage = hs.get_storage() self.clock = hs.get_clock() - self.federation_handler = hs.get_handlers().federation_handler + self.federation_handler = hs.get_federation_handler() @staticmethod async def _serialize_payload(store, room_id, event_and_contexts, backfilled): @@ -254,20 +254,20 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): return 200, {} -class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint): +class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint): """Called to clean up any data in DB for a given room, ready for the server to join the room. Request format: - POST /_synapse/replication/store_room_on_invite/:room_id/:txn_id + POST /_synapse/replication/store_room_on_outlier_membership/:room_id/:txn_id { "room_version": "1", } """ - NAME = "store_room_on_invite" + NAME = "store_room_on_outlier_membership" PATH_ARGS = ("room_id",) def __init__(self, hs): @@ -282,7 +282,7 @@ class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint): async def _handle_request(self, request, room_id): content = parse_json_object_from_request(request) room_version = KNOWN_ROOM_VERSIONS[content["room_version"]] - await self.store.maybe_store_room_on_invite(room_id, room_version) + await self.store.maybe_store_room_on_outlier_membership(room_id, room_version) return 200, {} @@ -291,4 +291,4 @@ def register_servlets(hs, http_server): ReplicationFederationSendEduRestServlet(hs).register(http_server) ReplicationGetQueryRestServlet(hs).register(http_server) ReplicationCleanRoomRestServlet(hs).register(http_server) - ReplicationStoreRoomOnInviteRestServlet(hs).register(http_server) + ReplicationStoreRoomOnOutlierMembershipRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 4c81e2d784..36071feb36 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py
@@ -36,7 +36,9 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): self.registration_handler = hs.get_registration_handler() @staticmethod - async def _serialize_payload(user_id, device_id, initial_display_name, is_guest): + async def _serialize_payload( + user_id, device_id, initial_display_name, is_guest, is_appservice_ghost + ): """ Args: device_id (str|None): Device ID to use, if None a new one is @@ -48,6 +50,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): "device_id": device_id, "initial_display_name": initial_display_name, "is_guest": is_guest, + "is_appservice_ghost": is_appservice_ghost, } async def _handle_request(self, request, user_id): @@ -56,9 +59,14 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): device_id = content["device_id"] initial_display_name = content["initial_display_name"] is_guest = content["is_guest"] + is_appservice_ghost = content["is_appservice_ghost"] device_id, access_token = await self.registration_handler.register_device( - user_id, device_id, initial_display_name, is_guest + user_id, + device_id, + initial_display_name, + is_guest, + is_appservice_ghost=is_appservice_ghost, ) return 200, {"device_id": device_id, "access_token": access_token} diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 30680baee8..84e002f934 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py
@@ -12,9 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, List, Optional, Tuple + +from twisted.web.http import Request from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint @@ -47,21 +48,28 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): def __init__(self, hs): super().__init__(hs) - self.federation_handler = hs.get_handlers().federation_handler + self.federation_handler = hs.get_federation_handler() self.store = hs.get_datastore() self.clock = hs.get_clock() @staticmethod - async def _serialize_payload( - requester, room_id, user_id, remote_room_hosts, content - ): + async def _serialize_payload( # type: ignore + requester: Requester, + room_id: str, + user_id: str, + remote_room_hosts: List[str], + content: JsonDict, + ) -> JsonDict: """ Args: - requester(Requester) - room_id (str) - user_id (str) - remote_room_hosts (list[str]): Servers to try and join via - content(dict): The event content to use for the join event + requester: The user making the request according to the access token + room_id: The ID of the room. + user_id: The ID of the user. + remote_room_hosts: Servers to try and join via + content: The event content to use for the join event + + Returns: + A dict representing the payload of the request. """ return { "requester": requester.serialize(), @@ -69,7 +77,9 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): "content": content, } - async def _handle_request(self, request, room_id, user_id): + async def _handle_request( # type: ignore + self, request: Request, room_id: str, user_id: str + ) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) remote_room_hosts = content["remote_room_hosts"] @@ -77,8 +87,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): requester = Requester.deserialize(self.store, content["requester"]) - if requester.user: - request.authenticated_entity = requester.user.to_string() + request.requester = requester logger.info("remote_join: %s into room: %s", user_id, room_id) @@ -119,14 +128,17 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): txn_id: Optional[str], requester: Requester, content: JsonDict, - ): + ) -> JsonDict: """ Args: - invite_event_id: ID of the invite to be rejected - txn_id: optional transaction ID supplied by the client - requester: user making the rejection request, according to the access token - content: additional content to include in the rejection event. + invite_event_id: The ID of the invite to be rejected. + txn_id: Optional transaction ID supplied by the client + requester: User making the rejection request, according to the access token + content: Additional content to include in the rejection event. Normally an empty dict. + + Returns: + A dict representing the payload of the request. 
""" return { "txn_id": txn_id, @@ -134,7 +146,9 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): "content": content, } - async def _handle_request(self, request, invite_event_id): + async def _handle_request( # type: ignore + self, request: Request, invite_event_id: str + ) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) txn_id = content["txn_id"] @@ -142,8 +156,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): requester = Requester.deserialize(self.store, content["requester"]) - if requester.user: - request.authenticated_entity = requester.user.to_string() + request.requester = requester # hopefully we're now on the master, so this won't recurse! event_id, stream_id = await self.member_handler.remote_reject_invite( @@ -176,18 +189,25 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): self.distributor = hs.get_distributor() @staticmethod - async def _serialize_payload(room_id, user_id, change): + async def _serialize_payload( # type: ignore + room_id: str, user_id: str, change: str + ) -> JsonDict: """ Args: - room_id (str) - user_id (str) - change (str): "left" + room_id: The ID of the room. + user_id: The ID of the user. + change: "left" + + Returns: + A dict representing the payload of the request. """ assert change == "left" return {} - def _handle_request(self, request, room_id, user_id, change): + def _handle_request( # type: ignore + self, request: Request, room_id: str, user_id: str, change: str + ) -> Tuple[int, JsonDict]: logger.info("user membership change: %s in %s", user_id, room_id) user = UserID.from_string(user_id) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index 9a3a694d5d..8fa104c8d3 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py
@@ -46,6 +46,12 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): "ratelimit": true, "extra_users": [], } + + 200 OK + + { "stream_id": 12345, "event_id": "$abcdef..." } + + The returned event ID may not match the sent event if it was deduplicated. """ NAME = "send_event" @@ -109,18 +115,23 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): ratelimit = content["ratelimit"] extra_users = [UserID.from_string(u) for u in content["extra_users"]] - if requester.user: - request.authenticated_entity = requester.user.to_string() + request.requester = requester logger.info( "Got event to send with ID: %s into room: %s", event.event_id, event.room_id ) - stream_id = await self.event_creation_handler.persist_and_notify_client_event( + event = await self.event_creation_handler.persist_and_notify_client_event( requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) - return 200, {"stream_id": stream_id} + return ( + 200, + { + "stream_id": event.internal_metadata.stream_ordering, + "event_id": event.event_id, + }, + ) def register_servlets(hs, http_server): diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index eb74903d68..0d39a93ed2 100644 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
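The tracker hunk below replaces the mutable default argument extra_tables=[] with an Optional[...] = None parameter plus a None check, the usual Python idiom for avoiding a default value that is shared between calls. A generic illustration of the pitfall being avoided (plain Python, not Synapse code):

from typing import List, Optional

def track_bad(table: str, seen=[]) -> List[str]:
    # The default list is created once and reused for every call.
    seen.append(table)
    return seen

print(track_bad("pushers"))          # ['pushers']
print(track_bad("deleted_pushers"))  # ['pushers', 'deleted_pushers']  <- state leaked

def track_good(table: str, seen: Optional[List[str]] = None) -> List[str]:
    # Defaulting to None and checking for it gives each call a fresh list.
    if seen is None:
        seen = []
    seen.append(table)
    return seen

print(track_good("pushers"))          # ['pushers']
print(track_good("deleted_pushers"))  # ['deleted_pushers']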
@@ -12,21 +12,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Optional, Tuple +from synapse.storage.types import Connection from synapse.storage.util.id_generators import _load_current_id class SlavedIdTracker: - def __init__(self, db_conn, table, column, extra_tables=[], step=1): + def __init__( + self, + db_conn: Connection, + table: str, + column: str, + extra_tables: Optional[List[Tuple[str, str]]] = None, + step: int = 1, + ): self.step = step self._current = _load_current_id(db_conn, table, column, step) - for table, column in extra_tables: - self.advance(None, _load_current_id(db_conn, table, column)) + if extra_tables: + for table, column in extra_tables: + self.advance(None, _load_current_id(db_conn, table, column)) - def advance(self, instance_name, new_id): + def advance(self, instance_name: Optional[str], new_id: int): self._current = (max if self.step > 0 else min)(self._current, new_id) - def get_current_token(self): + def get_current_token(self) -> int: """ Returns: diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 1f8dafe7ea..0f5b7adef7 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py
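The slaved client-IP store below swaps the deprecated Cache wrapper for LruCache (prefill becomes set) while keeping the existing behaviour of only forwarding an update when the previous sighting is older than LAST_SEEN_GRANULARITY. A generic sketch of that debouncing pattern, using a plain dict instead of Synapse's LruCache (the threshold value here is illustrative):

import time

LAST_SEEN_GRANULARITY_MS = 2 * 60 * 1000  # illustrative threshold

_last_seen = {}  # type: dict

def should_record(key, now_ms: int) -> bool:
    """Return True (and remember the time) only if the key is new or stale."""
    last = _last_seen.get(key)
    if last is not None and (now_ms - last) < LAST_SEEN_GRANULARITY_MS:
        return False
    _last_seen[key] = now_ms
    return True

now = int(time.time() * 1000)
assert should_record(("user", "token", "1.2.3.4"), now)          # first sighting: record it
assert not should_record(("user", "token", "1.2.3.4"), now + 1)  # seen 1ms ago: skip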
@@ -15,7 +15,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY -from synapse.util.caches.descriptors import Cache +from synapse.util.caches.lrucache import LruCache from ._base import BaseSlavedStore @@ -24,9 +24,9 @@ class SlavedClientIpStore(BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - self.client_ip_last_seen = Cache( - name="client_ip_last_seen", keylen=4, max_entries=50000 - ) + self.client_ip_last_seen = LruCache( + cache_name="client_ip_last_seen", keylen=4, max_size=50000 + ) # type: LruCache[tuple, int] async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): now = int(self._clock.time_msec()) @@ -41,7 +41,7 @@ class SlavedClientIpStore(BaseSlavedStore): if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: return - self.client_ip_last_seen.prefill(key, now) + self.client_ip_last_seen.set(key, now) self.hs.get_tcp_replication().send_user_ip( user_id, access_token, ip, user_agent, device_id, now diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index c418730ba8..045bd014da 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py
@@ -13,26 +13,33 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING from synapse.replication.tcp.streams import PushersStream from synapse.storage.database import DatabasePool from synapse.storage.databases.main.pusher import PusherWorkerStore +from synapse.storage.types import Connection from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): super().__init__(database, db_conn, hs) - self._pushers_id_gen = SlavedIdTracker( + self._pushers_id_gen = SlavedIdTracker( # type: ignore db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] ) - def get_pushers_stream_token(self): + def get_pushers_stream_token(self) -> int: return self._pushers_id_gen.get_current_token() - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, stream_name: str, instance_name: str, token, rows + ) -> None: if stream_name == PushersStream.NAME: - self._pushers_id_gen.advance(instance_name, token) + self._pushers_id_gen.advance(instance_name, token) # type: ignore return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index e165429cad..2618eb1e53 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py
@@ -141,21 +141,25 @@ class ReplicationDataHandler: if row.type != EventsStreamEventRow.TypeId: continue assert isinstance(row, EventsStreamRow) + assert isinstance(row.data, EventsStreamEventRow) - event = await self.store.get_event( - row.data.event_id, allow_rejected=True - ) - if event.rejected_reason: + if row.data.rejected: continue extra_users = () # type: Tuple[UserID, ...] - if event.type == EventTypes.Member: - extra_users = (UserID.from_string(event.state_key),) + if row.data.type == EventTypes.Member and row.data.state_key: + extra_users = (UserID.from_string(row.data.state_key),) max_token = self.store.get_room_max_token() event_pos = PersistedEventPosition(instance_name, token) - self.notifier.on_new_room_event( - event, event_pos, max_token, extra_users + self.notifier.on_new_room_event_args( + event_pos=event_pos, + max_room_stream_token=max_token, + extra_users=extra_users, + room_id=row.data.room_id, + event_type=row.data.type, + state_key=row.data.state_key, + membership=row.data.membership, ) # Notify any waiting deferreds. The list is ordered by position so we @@ -191,6 +195,10 @@ class ReplicationDataHandler: async def on_position(self, stream_name: str, instance_name: str, token: int): self.store.process_replication_rows(stream_name, instance_name, token, []) + # We poke the generic "replication" notifier to wake anything up that + # may be streaming. + self.notifier.notify_replication() + def on_remote_server_up(self, server: str): """Called when get a new REMOTE_SERVER_UP command.""" diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 8cd47770c1..ac532ed588 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py
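The next hunk widens the POSITION replication command so that it carries both the previously sent token and the new one, letting a receiver detect a gap by comparing prev_token against its own view of the sender's stream. A self-contained sketch of the wire format documented in the hunk (POSITION <stream_name> <instance_name> <prev_token> <new_token>); this is an illustration, not Synapse's PositionCommand class:

import attr

@attr.s(slots=True)
class PositionLine:
    stream_name = attr.ib(type=str)
    instance_name = attr.ib(type=str)
    prev_token = attr.ib(type=int)
    new_token = attr.ib(type=int)

    @classmethod
    def from_line(cls, line: str) -> "PositionLine":
        stream_name, instance_name, prev_token, new_token = line.split(" ", 3)
        return cls(stream_name, instance_name, int(prev_token), int(new_token))

    def to_line(self) -> str:
        return " ".join(
            (self.stream_name, self.instance_name, str(self.prev_token), str(self.new_token))
        )

cmd = PositionLine.from_line("events master 100 105")
assert cmd.to_line() == "events master 100 105"
# A receiver whose current token for this instance is not 100 knows it has missed
# rows and should fetch them out of band before trusting position 105.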
@@ -141,15 +141,23 @@ class RdataCommand(Command): class PositionCommand(Command): - """Sent by the server to tell the client the stream position without - needing to send an RDATA. + """Sent by an instance to tell others the stream position without needing to + send an RDATA. + + Two tokens are sent, the new position and the last position sent by the + instance (in an RDATA or other POSITION). The tokens are chosen so that *no* + rows were written by the instance between the `prev_token` and `new_token`. + (If an instance hasn't sent a position before then the new position can be + used for both.) Format:: - POSITION <stream_name> <instance_name> <token> + POSITION <stream_name> <instance_name> <prev_token> <new_token> - On receipt of a POSITION command clients should check if they have missed - any updates, and if so then fetch them out of band. + On receipt of a POSITION command instances should check if they have missed + any updates, and if so then fetch them out of band. Instances can check this + by comparing their view of the current token for the sending instance with + the included `prev_token`. The `<instance_name>` is the process that sent the command and is the source of the stream. @@ -157,18 +165,26 @@ class PositionCommand(Command): NAME = "POSITION" - def __init__(self, stream_name, instance_name, token): + def __init__(self, stream_name, instance_name, prev_token, new_token): self.stream_name = stream_name self.instance_name = instance_name - self.token = token + self.prev_token = prev_token + self.new_token = new_token @classmethod def from_line(cls, line): - stream_name, instance_name, token = line.split(" ", 2) - return cls(stream_name, instance_name, int(token)) + stream_name, instance_name, prev_token, new_token = line.split(" ", 3) + return cls(stream_name, instance_name, int(prev_token), int(new_token)) def to_line(self): - return " ".join((self.stream_name, self.instance_name, str(self.token))) + return " ".join( + ( + self.stream_name, + self.instance_name, + str(self.prev_token), + str(self.new_token), + ) + ) class ErrorCommand(_SimpleCommand): diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index b323841f73..95e5502bf2 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py
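The handler hunk below is the receiving side of the new POSITION semantics: when prev_token does not match the local current token, the instance pages through the missing updates until it has caught up to new_token. A simplified sketch of that catch-up loop; get_updates_since here is a toy stand-in that mimics the (updates, new_current_token, limited) contract used in the hunk:

from typing import List, Tuple

def get_updates_since(current: int, upto: int, batch: int = 2) -> Tuple[List[int], int, bool]:
    # Pretend every token between current and upto has one row; return at most `batch`.
    upper = min(current + batch, upto)
    return list(range(current + 1, upper + 1)), upper, upper < upto

def catch_up(current_token: int, prev_token: int, new_token: int) -> int:
    missing_updates = prev_token != current_token
    while missing_updates:
        updates, current_token, missing_updates = get_updates_since(current_token, new_token)
        # ...process `updates` here...
    return current_token

assert catch_up(current_token=100, prev_token=103, new_token=105) == 105  # gap: fetch in batches
assert catch_up(current_token=100, prev_token=100, new_token=105) == 100  # sender wrote nothing
# between 100 and 105, so there is nothing to fetch out of band.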
@@ -101,8 +101,9 @@ class ReplicationCommandHandler: self._streams_to_replicate = [] # type: List[Stream] for stream in self._streams.values(): - if stream.NAME == CachesStream.NAME: - # All workers can write to the cache invalidation stream. + if hs.config.redis.redis_enabled and stream.NAME == CachesStream.NAME: + # All workers can write to the cache invalidation stream when + # using redis. self._streams_to_replicate.append(stream) continue @@ -251,10 +252,9 @@ class ReplicationCommandHandler: using TCP. """ if hs.config.redis.redis_enabled: - import txredisapi - from synapse.replication.tcp.redis import ( RedisDirectTcpReplicationClientFactory, + lazyConnection, ) logger.info( @@ -271,7 +271,8 @@ class ReplicationCommandHandler: # connection after SUBSCRIBE is called). # First create the connection for sending commands. - outbound_redis_connection = txredisapi.lazyConnection( + outbound_redis_connection = lazyConnection( + reactor=hs.get_reactor(), host=hs.config.redis_host, port=hs.config.redis_port, password=hs.config.redis.redis_password, @@ -313,11 +314,14 @@ class ReplicationCommandHandler: # We respond with current position of all streams this instance # replicates. for stream in self.get_streams_to_replicate(): + # Note that we use the current token as the prev token here (rather + # than stream.last_token), as we can't be sure that there have been + # no rows written between last token and the current token (since we + # might be racing with the replication sending bg process). + current_token = stream.current_token(self._instance_name) self.send_command( PositionCommand( - stream.NAME, - self._instance_name, - stream.current_token(self._instance_name), + stream.NAME, self._instance_name, current_token, current_token, ) ) @@ -511,16 +515,16 @@ class ReplicationCommandHandler: # If the position token matches our current token then we're up to # date and there's nothing to do. Otherwise, fetch all updates # between then and now. - missing_updates = cmd.token != current_token + missing_updates = cmd.prev_token != current_token while missing_updates: logger.info( "Fetching replication rows for '%s' between %i and %i", stream_name, current_token, - cmd.token, + cmd.new_token, ) (updates, current_token, missing_updates) = await stream.get_updates_since( - cmd.instance_name, current_token, cmd.token + cmd.instance_name, current_token, cmd.new_token ) # TODO: add some tests for this @@ -536,11 +540,11 @@ class ReplicationCommandHandler: [stream.parse_row(row) for row in rows], ) - logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token) + logger.info("Caught up with stream '%s' to %i", stream_name, cmd.new_token) # We've now caught up to position sent to us, notify handler. await self._replication_data_handler.on_position( - cmd.stream_name, cmd.instance_name, cmd.token + cmd.stream_name, cmd.instance_name, cmd.new_token ) self._streams_by_connection.setdefault(conn, set()).add(stream_name) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 0b0d204e64..804da994ea 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py
@@ -51,10 +51,11 @@ import fcntl import logging import struct from inspect import isawaitable -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional from prometheus_client import Counter +from twisted.internet import task from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure @@ -152,9 +153,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.last_received_command = self.clock.time_msec() self.last_sent_command = 0 - self.time_we_closed = None # When we requested the connection be closed + # When we requested the connection be closed + self.time_we_closed = None # type: Optional[int] - self.received_ping = False # Have we reecived a ping from the other side + self.received_ping = False # Have we received a ping from the other side self.state = ConnectionStates.CONNECTING @@ -165,13 +167,12 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.pending_commands = [] # type: List[Command] # The LoopingCall for sending pings. - self._send_ping_loop = None + self._send_ping_loop = None # type: Optional[task.LoopingCall] # a logcontext which we use for processing incoming commands. We declare it as a # background process so that the CPU stats get reported to prometheus. ctx_name = "replication-conn-%s" % self.conn_id - self._logging_context = BackgroundProcessLoggingContext(ctx_name) - self._logging_context.request = ctx_name + self._logging_context = BackgroundProcessLoggingContext(ctx_name, ctx_name) def connectionMade(self): logger.info("[%s] Connection established", self.id()) diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index f225e533de..bc6ba709a7 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py
@@ -15,7 +15,7 @@ import logging from inspect import isawaitable -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import txredisapi @@ -166,7 +166,9 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): Args: cmd (Command) """ - run_as_background_process("send-cmd", self._async_send_command, cmd) + run_as_background_process( + "send-cmd", self._async_send_command, cmd, bg_start_span=False + ) async def _async_send_command(self, cmd: Command): """Encode a replication command and send it over our outbound connection""" @@ -228,3 +230,41 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory): p.password = self.password return p + + +def lazyConnection( + reactor, + host: str = "localhost", + port: int = 6379, + dbid: Optional[int] = None, + reconnect: bool = True, + charset: str = "utf-8", + password: Optional[str] = None, + connectTimeout: Optional[int] = None, + replyTimeout: Optional[int] = None, + convertNumbers: bool = True, +) -> txredisapi.RedisProtocol: + """Equivalent to `txredisapi.lazyConnection`, except allows specifying a + reactor. + """ + + isLazy = True + poolsize = 1 + + uuid = "%s:%d" % (host, port) + factory = txredisapi.RedisFactory( + uuid, + dbid, + poolsize, + isLazy, + txredisapi.ConnectionHandler, + charset, + password, + replyTimeout, + convertNumbers, + ) + factory.continueTrying = reconnect + for x in range(poolsize): + reactor.connectTCP(host, port, factory, connectTimeout) + + return factory.handler diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 687984e7a8..1d4ceac0f1 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py
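The streamer changes below do two things: they add a once-a-second looping call that re-checks whether the events stream position needs advertising, and they bail out early when no stream token has actually advanced, so the periodic poke stays cheap. A minimal sketch of that "only do work when a token moved" check using Twisted's LoopingCall; the stream objects are illustrative stand-ins:

from twisted.internet import task

class FakeStream:
    def __init__(self, name: str) -> None:
        self.NAME = name
        self.last_token = 0
        self._current = 0

    def current_token(self) -> int:
        return self._current

streams = [FakeStream("events")]

def on_poke() -> None:
    # Cheap early-out: nothing has advanced since we last sent anything.
    if all(s.last_token == s.current_token() for s in streams):
        return
    for s in streams:
        print("would send POSITION %s %d -> %d" % (s.NAME, s.last_token, s.current_token()))
        s.last_token = s.current_token()

loop = task.LoopingCall(on_poke)
# loop.start(1.0) would run the check every second once the reactor is running.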
@@ -23,7 +23,9 @@ from prometheus_client import Counter from twisted.internet.protocol import Factory from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.tcp.commands import PositionCommand from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol +from synapse.replication.tcp.streams import EventsStream from synapse.util.metrics import Measure stream_updates_counter = Counter( @@ -84,6 +86,23 @@ class ReplicationStreamer: # Set of streams to replicate. self.streams = self.command_handler.get_streams_to_replicate() + # If we have streams then we must have redis enabled or on master + assert ( + not self.streams + or hs.config.redis.redis_enabled + or not hs.config.worker.worker_app + ) + + # If we are replicating an event stream we want to periodically check if + # we should send updated POSITIONs. We do this as a looping call rather + # explicitly poking when the position advances (without new data to + # replicate) to reduce replication traffic (otherwise each writer would + # likely send a POSITION for each new event received over replication). + # + # Note that if the position hasn't advanced then we won't send anything. + if any(EventsStream.NAME == s.NAME for s in self.streams): + self.clock.looping_call(self.on_notifier_poke, 1000) + def on_notifier_poke(self): """Checks if there is actually any new data and sends it to the connections if there are. @@ -91,13 +110,23 @@ class ReplicationStreamer: This should get called each time new data is available, even if it is currently being executed, so that nothing gets missed """ - if not self.command_handler.connected(): + if not self.command_handler.connected() or not self.streams: # Don't bother if nothing is listening. We still need to advance # the stream tokens otherwise they'll fall behind forever for stream in self.streams: stream.discard_updates_and_advance() return + # We check up front to see if anything has actually changed, as we get + # poked because of changes that happened on other instances. + if all( + stream.last_token == stream.current_token(self._instance_name) + for stream in self.streams + ): + return + + # If there are updates then we need to set this even if we're already + # looping, as the loop needs to know that he might need to loop again. self.pending_updates = True if self.is_looping: @@ -136,6 +165,8 @@ class ReplicationStreamer: self._replication_torture_level / 1000.0 ) + last_token = stream.last_token + logger.debug( "Getting stream: %s: %s -> %s", stream.NAME, @@ -159,6 +190,30 @@ class ReplicationStreamer: ) stream_updates_counter.labels(stream.NAME).inc(len(updates)) + else: + # The token has advanced but there is no data to + # send, so we send a `POSITION` to inform other + # workers of the updated position. + if stream.NAME == EventsStream.NAME: + # XXX: We only do this for the EventStream as it + # turns out that e.g. account data streams share + # their "current token" with each other, meaning + # that it is *not* safe to send a POSITION. + logger.info( + "Sending position: %s -> %s", + stream.NAME, + current_token, + ) + self.command_handler.send_command( + PositionCommand( + stream.NAME, + self._instance_name, + last_token, + current_token, + ) + ) + continue + # Some streams return multiple rows with the same stream IDs, # we need to make sure they get sent out in batches. 
We do # this by setting the current token to all but the last of
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 54dccd15a6..61b282ab2d 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py
@@ -240,13 +240,18 @@ class BackfillStream(Stream): ROW_TYPE = BackfillStreamRow def __init__(self, hs): - store = hs.get_datastore() + self.store = hs.get_datastore() super().__init__( hs.get_instance_name(), - current_token_without_instance(store.get_current_backfill_token), - store.get_all_new_backfill_event_rows, + self._current_token, + self.store.get_all_new_backfill_event_rows, ) + def _current_token(self, instance_name: str) -> int: + # The backfill stream over replication operates on *positive* numbers, + # which means we need to negate it. + return -self.store._backfill_id_gen.get_current_token_for_writer(instance_name) + class PresenceStream(Stream): PresenceStreamRow = namedtuple( diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index ccc7ca30d8..86a62b71eb 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py
@@ -15,12 +15,15 @@ # limitations under the License. import heapq from collections.abc import Iterable -from typing import List, Tuple, Type +from typing import TYPE_CHECKING, List, Optional, Tuple, Type import attr from ._base import Stream, StreamUpdateResult, Token +if TYPE_CHECKING: + from synapse.server import HomeServer + """Handling of the 'events' replication stream This stream contains rows of various types. Each row therefore contains a 'type' @@ -81,12 +84,14 @@ class BaseEventsStreamRow: class EventsStreamEventRow(BaseEventsStreamRow): TypeId = "ev" - event_id = attr.ib() # str - room_id = attr.ib() # str - type = attr.ib() # str - state_key = attr.ib() # str, optional - redacts = attr.ib() # str, optional - relates_to = attr.ib() # str, optional + event_id = attr.ib(type=str) + room_id = attr.ib(type=str) + type = attr.ib(type=str) + state_key = attr.ib(type=Optional[str]) + redacts = attr.ib(type=Optional[str]) + relates_to = attr.ib(type=Optional[str]) + membership = attr.ib(type=Optional[str]) + rejected = attr.ib(type=bool) @attr.s(slots=True, frozen=True) @@ -113,7 +118,7 @@ class EventsStream(Stream): NAME = "events" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self._store = hs.get_datastore() super().__init__( hs.get_instance_name(), @@ -155,7 +160,7 @@ class EventsStream(Stream): # now we fetch up to that many rows from the events table event_rows = await self._store.get_all_new_forward_event_rows( - from_token, current_token, target_row_count + instance_name, from_token, current_token, target_row_count ) # type: List[Tuple] # we rely on get_all_new_forward_event_rows strictly honouring the limit, so @@ -180,7 +185,7 @@ class EventsStream(Stream): upper_limit, state_rows_limited, ) = await self._store.get_all_updated_current_state_deltas( - from_token, upper_limit, target_row_count + instance_name, from_token, upper_limit, target_row_count ) limited = limited or state_rows_limited @@ -189,7 +194,7 @@ class EventsStream(Stream): # not to bother with the limit. ex_outliers_rows = await self._store.get_ex_outlier_stream_rows( - from_token, upper_limit + instance_name, from_token, upper_limit ) # type: List[Tuple] # we now need to turn the raw database rows returned into tuples suitable diff --git a/synapse/res/templates/notif.html b/synapse/res/templates/notif.html
index 1a6c70b562..0aaef97df8 100644 --- a/synapse/res/templates/notif.html +++ b/synapse/res/templates/notif.html
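The template hunks from here on switch Jinja block tags from {% ... %} to {%- ... %}; the leading minus tells Jinja to strip the whitespace (including the newline) immediately before the tag, so the rendered notification emails are not padded with blank lines around every condition and loop. A tiny, self-contained illustration of the difference (plain Jinja2, unrelated to these templates):

from jinja2 import Template

messy = Template("Hello\n{% if True %}\nworld\n{% endif %}\n!")
tidy = Template("Hello\n{%- if True %}\nworld\n{%- endif %}\n!")

print(repr(messy.render()))  # 'Hello\n\nworld\n\n!'  (stray blank lines survive)
print(repr(tidy.render()))   # 'Hello\nworld\n!'      (whitespace before tags stripped)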
@@ -1,41 +1,47 @@ -{% for message in notif.messages %} +{%- for message in notif.messages %} <tr class="{{ "historical_message" if message.is_historical else "message" }}"> <td class="sender_avatar"> - {% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %} - {% if message.sender_avatar_url %} + {%- if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %} + {%- if message.sender_avatar_url %} <img alt="" class="sender_avatar" src="{{ message.sender_avatar_url|mxc_to_http(32,32) }}" /> - {% else %} - {% if message.sender_hash % 3 == 0 %} + {%- else %} + {%- if message.sender_hash % 3 == 0 %} <img class="sender_avatar" src="https://riot.im/img/external/avatar-1.png" /> - {% elif message.sender_hash % 3 == 1 %} + {%- elif message.sender_hash % 3 == 1 %} <img class="sender_avatar" src="https://riot.im/img/external/avatar-2.png" /> - {% else %} + {%- else %} <img class="sender_avatar" src="https://riot.im/img/external/avatar-3.png" /> - {% endif %} - {% endif %} - {% endif %} + {%- endif %} + {%- endif %} + {%- endif %} </td> <td class="message_contents"> - {% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %} - <div class="sender_name">{% if message.msgtype == "m.emote" %}*{% endif %} {{ message.sender_name }}</div> - {% endif %} + {%- if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %} + <div class="sender_name">{%- if message.msgtype == "m.emote" %}*{%- endif %} {{ message.sender_name }}</div> + {%- endif %} <div class="message_body"> - {% if message.msgtype == "m.text" %} - {{ message.body_text_html }} - {% elif message.msgtype == "m.emote" %} - {{ message.body_text_html }} - {% elif message.msgtype == "m.notice" %} - {{ message.body_text_html }} - {% elif message.msgtype == "m.image" %} - <img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" /> - {% elif message.msgtype == "m.file" %} - <span class="filename">{{ message.body_text_plain }}</span> - {% endif %} + {%- if message.event_type == "m.room.encrypted" %} + An encrypted message. + {%- elif message.event_type == "m.room.message" %} + {%- if message.msgtype == "m.text" %} + {{ message.body_text_html }} + {%- elif message.msgtype == "m.emote" %} + {{ message.body_text_html }} + {%- elif message.msgtype == "m.notice" %} + {{ message.body_text_html }} + {%- elif message.msgtype == "m.image" and message.image_url %} + <img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" /> + {%- elif message.msgtype == "m.file" %} + <span class="filename">{{ message.body_text_plain }}</span> + {%- else %} + A message with unrecognised content. + {%- endif %} + {%- endif %} </div> </td> <td class="message_time">{{ message.ts|format_ts("%H:%M") }}</td> </tr> -{% endfor %} +{%- endfor %} <tr class="notif_link"> <td></td> <td> diff --git a/synapse/res/templates/notif.txt b/synapse/res/templates/notif.txt
index a37bee9833..1ee7da3c50 100644 --- a/synapse/res/templates/notif.txt +++ b/synapse/res/templates/notif.txt
@@ -1,16 +1,22 @@ -{% for message in notif.messages %} -{% if message.msgtype == "m.emote" %}* {% endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }}) -{% if message.msgtype == "m.text" %} +{%- for message in notif.messages %} +{%- if message.event_type == "m.room.encrypted" %} +An encrypted message. +{%- elif message.event_type == "m.room.message" %} +{%- if message.msgtype == "m.emote" %}* {%- endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }}) +{%- if message.msgtype == "m.text" %} {{ message.body_text_plain }} -{% elif message.msgtype == "m.emote" %} +{%- elif message.msgtype == "m.emote" %} {{ message.body_text_plain }} -{% elif message.msgtype == "m.notice" %} +{%- elif message.msgtype == "m.notice" %} {{ message.body_text_plain }} -{% elif message.msgtype == "m.image" %} +{%- elif message.msgtype == "m.image" %} {{ message.body_text_plain }} -{% elif message.msgtype == "m.file" %} +{%- elif message.msgtype == "m.file" %} {{ message.body_text_plain }} -{% endif %} -{% endfor %} +{%- else %} +A message with unrecognised content. +{%- endif %} +{%- endif %} +{%- endfor %} View {{ room.title }} at {{ notif.link }} diff --git a/synapse/res/templates/notif_mail.html b/synapse/res/templates/notif_mail.html
index a2dfeb9e9f..27d4182790 100644 --- a/synapse/res/templates/notif_mail.html +++ b/synapse/res/templates/notif_mail.html
@@ -2,8 +2,8 @@ <html lang="en"> <head> <style type="text/css"> - {% include 'mail.css' without context %} - {% include "mail-%s.css" % app_name ignore missing without context %} + {%- include 'mail.css' without context %} + {%- include "mail-%s.css" % app_name ignore missing without context %} </style> </head> <body> @@ -18,21 +18,21 @@ <div class="summarytext">{{ summary_text }}</div> </td> <td class="logo"> - {% if app_name == "Riot" %} + {%- if app_name == "Riot" %} <img src="http://riot.im/img/external/riot-logo-email.png" width="83" height="83" alt="[Riot]"/> - {% elif app_name == "Vector" %} + {%- elif app_name == "Vector" %} <img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/> - {% elif app_name == "Element" %} + {%- elif app_name == "Element" %} <img src="https://static.element.io/images/email-logo.png" width="83" height="83" alt="[Element]"/> - {% else %} + {%- else %} <img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/> - {% endif %} + {%- endif %} </td> </tr> </table> - {% for room in rooms %} - {% include 'room.html' with context %} - {% endfor %} + {%- for room in rooms %} + {%- include 'room.html' with context %} + {%- endfor %} <div class="footer"> <a href="{{ unsubscribe_link }}">Unsubscribe</a> <br/> @@ -41,12 +41,12 @@ Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because an event was received at {{ reason.received_at|format_ts("%c") }} which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} ({{ reason.delay_before_mail_ms }}) mins ago, - {% if reason.last_sent_ts %} + {%- if reason.last_sent_ts %} and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }}, which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago. - {% else %} + {%- else %} and we don't have a last time we sent a mail for this room. - {% endif %} + {%- endif %} </div> </div> </td> diff --git a/synapse/res/templates/notif_mail.txt b/synapse/res/templates/notif_mail.txt
index 24843042a5..df3c253979 100644 --- a/synapse/res/templates/notif_mail.txt +++ b/synapse/res/templates/notif_mail.txt
@@ -2,9 +2,9 @@ Hi {{ user_display_name }}, {{ summary_text }} -{% for room in rooms %} -{% include 'room.txt' with context %} -{% endfor %} +{%- for room in rooms %} +{%- include 'room.txt' with context %} +{%- endfor %} You can disable these notifications at {{ unsubscribe_link }} diff --git a/synapse/res/templates/room.html b/synapse/res/templates/room.html
index b8525fef88..4fc6f6ac9b 100644 --- a/synapse/res/templates/room.html +++ b/synapse/res/templates/room.html
@@ -1,23 +1,23 @@ <table class="room"> <tr class="room_header"> <td class="room_avatar"> - {% if room.avatar_url %} + {%- if room.avatar_url %} <img alt="" src="{{ room.avatar_url|mxc_to_http(48,48) }}" /> - {% else %} - {% if room.hash % 3 == 0 %} + {%- else %} + {%- if room.hash % 3 == 0 %} <img alt="" src="https://riot.im/img/external/avatar-1.png" /> - {% elif room.hash % 3 == 1 %} + {%- elif room.hash % 3 == 1 %} <img alt="" src="https://riot.im/img/external/avatar-2.png" /> - {% else %} + {%- else %} <img alt="" src="https://riot.im/img/external/avatar-3.png" /> - {% endif %} - {% endif %} + {%- endif %} + {%- endif %} </td> <td class="room_name" colspan="2"> {{ room.title }} </td> </tr> - {% if room.invite %} + {%- if room.invite %} <tr> <td></td> <td> @@ -25,9 +25,9 @@ </td> <td></td> </tr> - {% else %} - {% for notif in room.notifs %} - {% include 'notif.html' with context %} - {% endfor %} - {% endif %} + {%- else %} + {%- for notif in room.notifs %} + {%- include 'notif.html' with context %} + {%- endfor %} + {%- endif %} </table> diff --git a/synapse/res/templates/room.txt b/synapse/res/templates/room.txt
index 84648c710e..df841e9e6f 100644 --- a/synapse/res/templates/room.txt +++ b/synapse/res/templates/room.txt
@@ -1,9 +1,9 @@ {{ room.title }} -{% if room.invite %} +{%- if room.invite %} You've been invited, join at {{ room.link }} -{% else %} - {% for notif in room.notifs %} - {% include 'notif.txt' with context %} - {% endfor %} -{% endif %} +{%- else %} + {%- for notif in room.notifs %} + {%- include 'notif.txt' with context %} + {%- endfor %} +{%- endif %} diff --git a/synapse/res/username_picker/index.html b/synapse/res/username_picker/index.html new file mode 100644
index 0000000000..37ea8bb6d8 --- /dev/null +++ b/synapse/res/username_picker/index.html
@@ -0,0 +1,19 @@ +<!DOCTYPE html> +<html lang="en"> + <head> + <title>Synapse Login</title> + <link rel="stylesheet" href="style.css" type="text/css" /> + </head> + <body> + <div class="card"> + <form method="post" class="form__input" id="form" action="submit"> + <label for="field-username">Please pick your username:</label> + <input type="text" name="username" id="field-username" autofocus=""> + <input type="submit" class="button button--full-width" id="button-submit" value="Submit"> + </form> + <!-- this is used for feedback --> + <div role=alert class="tooltip hidden" id="message"></div> + <script src="script.js"></script> + </div> + </body> +</html> diff --git a/synapse/res/username_picker/script.js b/synapse/res/username_picker/script.js new file mode 100644
index 0000000000..416a7c6f41 --- /dev/null +++ b/synapse/res/username_picker/script.js
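The new picker script below validates the requested username locally (lowercase letters, digits, and . _ = - / only), then asks a check endpoint whether it is available and expects JSON containing either error or available. The same character rule expressed in Python for clarity; the regex is transcribed from the script, everything else is illustrative:

import re

# Reject anything outside a-z, 0-9, '.', '_', '=', '-', '/' (as the script does).
DISALLOWED = re.compile(r"[^a-z0-9._=\-/]")

def username_is_valid(username: str) -> bool:
    return bool(username) and not DISALLOWED.search(username)

assert username_is_valid("alice_42")
assert not username_is_valid("Alice")  # uppercase is rejected
assert not username_is_valid("")       # empty input is rejected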
@@ -0,0 +1,95 @@ +let inputField = document.getElementById("field-username"); +let inputForm = document.getElementById("form"); +let submitButton = document.getElementById("button-submit"); +let message = document.getElementById("message"); + +// Submit username and receive response +function showMessage(messageText) { + // Unhide the message text + message.classList.remove("hidden"); + + message.textContent = messageText; +}; + +function doSubmit() { + showMessage("Success. Please wait a moment for your browser to redirect."); + + // remove the event handler before re-submitting the form. + delete inputForm.onsubmit; + inputForm.submit(); +} + +function onResponse(response) { + // Display message + showMessage(response); + + // Enable submit button and input field + submitButton.classList.remove('button--disabled'); + submitButton.value = "Submit"; +}; + +let allowedUsernameCharacters = RegExp("[^a-z0-9\\.\\_\\=\\-\\/]"); +function usernameIsValid(username) { + return !allowedUsernameCharacters.test(username); +} +let allowedCharactersString = "lowercase letters, digits, ., _, -, /, ="; + +function buildQueryString(params) { + return Object.keys(params) + .map(k => encodeURIComponent(k) + '=' + encodeURIComponent(params[k])) + .join('&'); +} + +function submitUsername(username) { + if(username.length == 0) { + onResponse("Please enter a username."); + return; + } + if(!usernameIsValid(username)) { + onResponse("Invalid username. Only the following characters are allowed: " + allowedCharactersString); + return; + } + + // if this browser doesn't support fetch, skip the availability check. + if(!window.fetch) { + doSubmit(); + return; + } + + let check_uri = 'check?' + buildQueryString({"username": username}); + fetch(check_uri, { + // include the cookie + "credentials": "same-origin", + }).then((response) => { + if(!response.ok) { + // for non-200 responses, raise the body of the response as an exception + return response.text().then((text) => { throw text; }); + } else { + return response.json(); + } + }).then((json) => { + if(json.error) { + throw json.error; + } else if(json.available) { + doSubmit(); + } else { + onResponse("This username is not available, please choose another."); + } + }).catch((err) => { + onResponse("Error checking username availability: " + err); + }); +} + +function clickSubmit() { + event.preventDefault(); + if(submitButton.classList.contains('button--disabled')) { return; } + + // Disable submit button and input field + submitButton.classList.add('button--disabled'); + + // Submit username + submitButton.value = "Checking..."; + submitUsername(inputField.value); +}; + +inputForm.onsubmit = clickSubmit; diff --git a/synapse/res/username_picker/style.css b/synapse/res/username_picker/style.css new file mode 100644
index 0000000000..745bd4c684 --- /dev/null +++ b/synapse/res/username_picker/style.css
@@ -0,0 +1,27 @@ +input[type="text"] { + font-size: 100%; + background-color: #ededf0; + border: 1px solid #fff; + border-radius: .2em; + padding: .5em .9em; + display: block; + width: 26em; +} + +.button--disabled { + border-color: #fff; + background-color: transparent; + color: #000; + text-transform: none; +} + +.hidden { + display: none; +} + +.tooltip { + background-color: #f9f9fa; + padding: 1em; + margin: 1em 0; +} + diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 57cac22252..6f7dc06503 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py
@@ -21,17 +21,16 @@ import synapse from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.server import JsonResource from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.rest.admin._base import ( - admin_patterns, - assert_requester_is_admin, - historical_admin_path_patterns, -) +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin from synapse.rest.admin.devices import ( DeleteDevicesRestServlet, DeviceRestServlet, DevicesRestServlet, ) -from synapse.rest.admin.event_reports import EventReportsRestServlet +from synapse.rest.admin.event_reports import ( + EventReportDetailRestServlet, + EventReportsRestServlet, +) from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet @@ -39,24 +38,30 @@ from synapse.rest.admin.rooms import ( DeleteRoomRestServlet, JoinRoomAliasServlet, ListRoomRestServlet, + MakeRoomAdminRestServlet, RoomMembersRestServlet, RoomRestServlet, ShutdownRoomRestServlet, ) from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet +from synapse.rest.admin.statistics import UserMediaStatisticsRestServlet from synapse.rest.admin.users import ( AccountValidityRenewServlet, DeactivateAccountRestServlet, + PushersRestServlet, ResetPasswordRestServlet, SearchUsersRestServlet, UserAdminServlet, + UserMediaRestServlet, UserMembershipRestServlet, UserRegisterServlet, UserRestServletV2, UsersRestServlet, UsersRestServletV2, + UserTokenRestServlet, WhoisRestServlet, ) +from synapse.types import RoomStreamToken from synapse.util.versionstring import get_version_string logger = logging.getLogger(__name__) @@ -76,7 +81,7 @@ class VersionServlet(RestServlet): class PurgeHistoryRestServlet(RestServlet): - PATTERNS = historical_admin_path_patterns( + PATTERNS = admin_patterns( "/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?" 
) @@ -109,7 +114,9 @@ class PurgeHistoryRestServlet(RestServlet): if event.room_id != room_id: raise SynapseError(400, "Event is for wrong room.") - room_token = await self.store.get_topological_token_for_event(event_id) + room_token = RoomStreamToken( + event.depth, event.internal_metadata.stream_ordering + ) token = await room_token.to_string(self.store) logger.info("[purge] purging up to token %s (event_id %s)", token, event_id) @@ -159,9 +166,7 @@ class PurgeHistoryRestServlet(RestServlet): class PurgeHistoryStatusRestServlet(RestServlet): - PATTERNS = historical_admin_path_patterns( - "/purge_history_status/(?P<purge_id>[^/]+)" - ) + PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]+)") def __init__(self, hs): """ @@ -212,13 +217,19 @@ def register_servlets(hs, http_server): SendServerNoticeServlet(hs).register(http_server) VersionServlet(hs).register(http_server) UserAdminServlet(hs).register(http_server) + UserMediaRestServlet(hs).register(http_server) UserMembershipRestServlet(hs).register(http_server) + UserTokenRestServlet(hs).register(http_server) UserRestServletV2(hs).register(http_server) UsersRestServletV2(hs).register(http_server) DeviceRestServlet(hs).register(http_server) DevicesRestServlet(hs).register(http_server) DeleteDevicesRestServlet(hs).register(http_server) + UserMediaStatisticsRestServlet(hs).register(http_server) + EventReportDetailRestServlet(hs).register(http_server) EventReportsRestServlet(hs).register(http_server) + PushersRestServlet(hs).register(http_server) + MakeRoomAdminRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource(hs, http_server): diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py
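The two servlets touched above expose the purge-history admin API. A rough usage sketch with the requests library follows; the server URL, admin access token, room ID and event ID are placeholders.

    # Sketch: purge history up to a given event, then poll the purge status.
    import requests

    SERVER = "https://synapse.example.com"
    HEADERS = {"Authorization": "Bearer <admin access token>"}

    resp = requests.post(
        f"{SERVER}/_synapse/admin/v1/purge_history/!room:example.com/$event_id",
        headers=HEADERS,
        json={"delete_local_events": False},
    )
    purge_id = resp.json()["purge_id"]

    status = requests.get(
        f"{SERVER}/_synapse/admin/v1/purge_history_status/{purge_id}", headers=HEADERS
    )
    print(status.json())  # e.g. {"status": "complete"}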
index db9fea263a..e09234c644 100644
--- a/synapse/rest/admin/_base.py
+++ b/synapse/rest/admin/_base.py
@@ -22,28 +22,6 @@ from synapse.api.errors import AuthError from synapse.types import UserID -def historical_admin_path_patterns(path_regex): - """Returns the list of patterns for an admin endpoint, including historical ones - - This is a backwards-compatibility hack. Previously, the Admin API was exposed at - various paths under /_matrix/client. This function returns a list of patterns - matching those paths (as well as the new one), so that existing scripts which rely - on the endpoints being available there are not broken. - - Note that this should only be used for existing endpoints: new ones should just - register for the /_synapse/admin path. - """ - return [ - re.compile(prefix + path_regex) - for prefix in ( - "^/_synapse/admin/v1", - "^/_matrix/client/api/v1/admin", - "^/_matrix/client/unstable/admin", - "^/_matrix/client/r0/admin", - ) - ] - - def admin_patterns(path_regex: str, version: str = "v1"): """Returns the list of patterns for an admin endpoint diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
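With the legacy prefixes removed, every endpoint converted in this change matches only under /_synapse/admin. The following simplified stand-in for admin_patterns is for illustration only (the real helper remains in synapse/rest/admin/_base.py) and shows the practical effect:

    import re

    def admin_patterns(path_regex: str, version: str = "v1"):
        # Simplified re-implementation for illustration only.
        return [re.compile("^/_synapse/admin/" + version + path_regex)]

    patterns = admin_patterns("/delete_group/(?P<group_id>[^/]*)")
    assert patterns[0].match("/_synapse/admin/v1/delete_group/+group:example.com")
    # The historical /_matrix/client/... admin prefixes no longer match:
    assert not patterns[0].match("/_matrix/client/r0/admin/delete_group/+group:example.com")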
index a163863322..ffd3aa38f7 100644
--- a/synapse/rest/admin/devices.py
+++ b/synapse/rest/admin/devices.py
@@ -119,7 +119,7 @@ class DevicesRestServlet(RestServlet): raise NotFoundError("Unknown user") devices = await self.device_handler.get_devices_by_user(target_user.to_string()) - return 200, {"devices": devices} + return 200, {"devices": devices, "total": len(devices)} class DeleteDevicesRestServlet(RestServlet): diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py
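With this change the device-list admin endpoint reports a "total" count alongside the "devices" array. A usage sketch, assuming the endpoint's usual /_synapse/admin/v2/users/<user_id>/devices path and with a placeholder URL and admin token:

    import requests

    resp = requests.get(
        "https://synapse.example.com/_synapse/admin/v2/users/@alice:example.com/devices",
        headers={"Authorization": "Bearer <admin access token>"},
    )
    body = resp.json()
    print("total devices:", body["total"])
    for device in body["devices"]:
        print(device.get("device_id"), device.get("display_name"))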
index 5b8d0594cd..fd482f0e32 100644
--- a/synapse/rest/admin/event_reports.py
+++ b/synapse/rest/admin/event_reports.py
@@ -15,7 +15,7 @@ import logging -from synapse.api.errors import Codes, SynapseError +from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin @@ -86,3 +86,47 @@ class EventReportsRestServlet(RestServlet): ret["next_token"] = start + len(event_reports) return 200, ret + + +class EventReportDetailRestServlet(RestServlet): + """ + Get a specific reported event that is known to the homeserver. Results are returned + in a dictionary containing report information. + The requester must have administrator access in Synapse. + + GET /_synapse/admin/v1/event_reports/<report_id> + returns: + 200 OK with details report if success otherwise an error. + + Args: + The parameter `report_id` is the ID of the event report in the database. + Returns: + JSON blob of information about the event report + """ + + PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$") + + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request, report_id): + await assert_requester_is_admin(self.auth, request) + + message = ( + "The report_id parameter must be a string representing a positive integer." + ) + try: + report_id = int(report_id) + except ValueError: + raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) + + if report_id < 0: + raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) + + ret = await self.store.get_event_report(report_id) + if not ret: + raise NotFoundError("Event report not found") + + return 200, ret diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py
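The new EventReportDetailRestServlet is a plain GET by numeric report ID. A usage sketch (placeholder URL and admin token; the report ID 42 is illustrative):

    import requests

    resp = requests.get(
        "https://synapse.example.com/_synapse/admin/v1/event_reports/42",
        headers={"Authorization": "Bearer <admin access token>"},
    )
    if resp.status_code == 404:
        print("no event report with that ID")
    else:
        report = resp.json()
        print(report.get("room_id"), report.get("event_id"), report.get("reason"))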
index 0b54ca09f4..d0c86b204a 100644
--- a/synapse/rest/admin/groups.py
+++ b/synapse/rest/admin/groups.py
@@ -16,10 +16,7 @@ import logging from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet -from synapse.rest.admin._base import ( - assert_user_is_admin, - historical_admin_path_patterns, -) +from synapse.rest.admin._base import admin_patterns, assert_user_is_admin logger = logging.getLogger(__name__) @@ -28,7 +25,7 @@ class DeleteGroupAdminRestServlet(RestServlet): """Allows deleting of local groups """ - PATTERNS = historical_admin_path_patterns("/delete_group/(?P<group_id>[^/]*)") + PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)") def __init__(self, hs): self.group_server = hs.get_groups_server_handler() diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index ee75095c0e..c82b4f87d6 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -16,12 +16,12 @@ import logging -from synapse.api.errors import AuthError -from synapse.http.servlet import RestServlet, parse_integer +from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError +from synapse.http.servlet import RestServlet, parse_boolean, parse_integer from synapse.rest.admin._base import ( + admin_patterns, assert_requester_is_admin, assert_user_is_admin, - historical_admin_path_patterns, ) logger = logging.getLogger(__name__) @@ -33,10 +33,10 @@ class QuarantineMediaInRoom(RestServlet): """ PATTERNS = ( - historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media/quarantine") + admin_patterns("/room/(?P<room_id>[^/]+)/media/quarantine") + # This path kept around for legacy reasons - historical_admin_path_patterns("/quarantine_media/(?P<room_id>[^/]+)") + admin_patterns("/quarantine_media/(?P<room_id>[^/]+)") ) def __init__(self, hs): @@ -62,9 +62,7 @@ class QuarantineMediaByUser(RestServlet): this server. """ - PATTERNS = historical_admin_path_patterns( - "/user/(?P<user_id>[^/]+)/media/quarantine" - ) + PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine") def __init__(self, hs): self.store = hs.get_datastore() @@ -89,7 +87,7 @@ class QuarantineMediaByID(RestServlet): it via this server. """ - PATTERNS = historical_admin_path_patterns( + PATTERNS = admin_patterns( "/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)" ) @@ -115,7 +113,7 @@ class ListMediaInRoom(RestServlet): """Lists all of the media in a given room. """ - PATTERNS = historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media") + PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media") def __init__(self, hs): self.store = hs.get_datastore() @@ -133,7 +131,7 @@ class ListMediaInRoom(RestServlet): class PurgeMediaCacheRestServlet(RestServlet): - PATTERNS = historical_admin_path_patterns("/purge_media_cache") + PATTERNS = admin_patterns("/purge_media_cache") def __init__(self, hs): self.media_repository = hs.get_media_repository() @@ -150,6 +148,80 @@ class PurgeMediaCacheRestServlet(RestServlet): return 200, ret +class DeleteMediaByID(RestServlet): + """Delete local media by a given ID. Removes it from this server. + """ + + PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)") + + def __init__(self, hs): + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.server_name = hs.hostname + self.media_repository = hs.get_media_repository() + + async def on_DELETE(self, request, server_name: str, media_id: str): + await assert_requester_is_admin(self.auth, request) + + if self.server_name != server_name: + raise SynapseError(400, "Can only delete local media") + + if await self.store.get_local_media(media_id) is None: + raise NotFoundError("Unknown media") + + logging.info("Deleting local media by ID: %s", media_id) + + deleted_media, total = await self.media_repository.delete_local_media(media_id) + return 200, {"deleted_media": deleted_media, "total": total} + + +class DeleteMediaByDateSize(RestServlet): + """Delete local media and local copies of remote media by + timestamp and size. 
+ """ + + PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete") + + def __init__(self, hs): + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.server_name = hs.hostname + self.media_repository = hs.get_media_repository() + + async def on_POST(self, request, server_name: str): + await assert_requester_is_admin(self.auth, request) + + before_ts = parse_integer(request, "before_ts", required=True) + size_gt = parse_integer(request, "size_gt", default=0) + keep_profiles = parse_boolean(request, "keep_profiles", default=True) + + if before_ts < 0: + raise SynapseError( + 400, + "Query parameter before_ts must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + if size_gt < 0: + raise SynapseError( + 400, + "Query parameter size_gt must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if self.server_name != server_name: + raise SynapseError(400, "Can only delete local media") + + logging.info( + "Deleting local media by timestamp: %s, size larger than: %s, keep profile media: %s" + % (before_ts, size_gt, keep_profiles) + ) + + deleted_media, total = await self.media_repository.delete_old_local_media( + before_ts, size_gt, keep_profiles + ) + return 200, {"deleted_media": deleted_media, "total": total} + + def register_servlets_for_media_repo(hs, http_server): """ Media repo specific APIs. @@ -159,3 +231,5 @@ def register_servlets_for_media_repo(hs, http_server): QuarantineMediaByID(hs).register(http_server) QuarantineMediaByUser(hs).register(http_server) ListMediaInRoom(hs).register(http_server) + DeleteMediaByID(hs).register(http_server) + DeleteMediaByDateSize(hs).register(http_server) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 09726d52d6..ab7cc9102a 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -14,10 +14,10 @@ # limitations under the License. import logging from http import HTTPStatus -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional, Tuple -from synapse.api.constants import EventTypes, JoinRules -from synapse.api.errors import Codes, NotFoundError, SynapseError +from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -25,14 +25,18 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.rest.admin._base import ( admin_patterns, assert_requester_is_admin, assert_user_is_admin, - historical_admin_path_patterns, ) from synapse.storage.databases.main.room import RoomSortOrder -from synapse.types import RoomAlias, RoomID, UserID, create_requester +from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester + +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -44,14 +48,16 @@ class ShutdownRoomRestServlet(RestServlet): joined to the new room. """ - PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P<room_id>[^/]+)") + PATTERNS = admin_patterns("/shutdown_room/(?P<room_id>[^/]+)") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.room_shutdown_handler = hs.get_room_shutdown_handler() - async def on_POST(self, request, room_id): + async def on_POST( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -71,25 +77,31 @@ class ShutdownRoomRestServlet(RestServlet): class DeleteRoomRestServlet(RestServlet): - """Delete a room from server. It is a combination and improvement of - shut down and purge room. + """Delete a room from server. + + It is a combination and improvement of shutdown and purge room. + Shuts down a room by removing all local users from the room. Blocking all future invites and joins to the room is optional. + If desired any local aliases will be repointed to a new room - created by `new_room_user_id` and kicked users will be auto + created by `new_room_user_id` and kicked users will be auto- joined to the new room. - It will remove all trace of a room from the database. + + If 'purge' is true, it will remove all traces of a room from the database. 
""" PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.room_shutdown_handler = hs.get_room_shutdown_handler() self.pagination_handler = hs.get_pagination_handler() - async def on_POST(self, request, room_id): + async def on_POST( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -111,6 +123,14 @@ class DeleteRoomRestServlet(RestServlet): Codes.BAD_JSON, ) + force_purge = content.get("force_purge", False) + if not isinstance(force_purge, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'force_purge' must be a boolean, if given", + Codes.BAD_JSON, + ) + ret = await self.room_shutdown_handler.shutdown_room( room_id=room_id, new_room_user_id=content.get("new_room_user_id"), @@ -122,7 +142,7 @@ class DeleteRoomRestServlet(RestServlet): # Purge room if purge: - await self.pagination_handler.purge_room(room_id) + await self.pagination_handler.purge_room(room_id, force=force_purge) return (200, ret) @@ -135,12 +155,12 @@ class ListRoomRestServlet(RestServlet): PATTERNS = admin_patterns("/rooms$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() - self.admin_handler = hs.get_handlers().admin_handler + self.admin_handler = hs.get_admin_handler() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -225,19 +245,24 @@ class RoomRestServlet(RestServlet): PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) ret = await self.store.get_room_with_stats(room_id) if not ret: raise NotFoundError("Room not found") - return 200, ret + members = await self.store.get_users_in_room(room_id) + ret["joined_local_devices"] = await self.store.count_devices_by_users(members) + + return (200, ret) class RoomMembersRestServlet(RestServlet): @@ -247,12 +272,14 @@ class RoomMembersRestServlet(RestServlet): PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) ret = await self.store.get_room(room_id) @@ -269,14 +296,16 @@ class JoinRoomAliasServlet(RestServlet): PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.room_member_handler = hs.get_room_member_handler() - self.admin_handler = hs.get_handlers().admin_handler + self.admin_handler = hs.get_admin_handler() self.state_handler = hs.get_state_handler() - async def on_POST(self, request, room_identifier): + async def on_POST( + self, request: SynapseRequest, room_identifier: str + ) -> 
Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -303,13 +332,14 @@ class JoinRoomAliasServlet(RestServlet): handler = self.room_member_handler room_alias = RoomAlias.from_string(room_identifier) room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias) - room_id = room_id.to_string() else: raise SynapseError( 400, "%s was not legal room ID or room alias" % (room_identifier,) ) - fake_requester = create_requester(target_user) + fake_requester = create_requester( + target_user, authenticated_entity=requester.authenticated_entity + ) # send invite if room has "JoinRules.INVITE" room_state = await self.state_handler.get_current_state(room_id) @@ -338,3 +368,134 @@ class JoinRoomAliasServlet(RestServlet): ) return 200, {"room_id": room_id} + + +class MakeRoomAdminRestServlet(RestServlet): + """Allows a server admin to get power in a room if a local user has power in + a room. Will also invite the user if they're not in the room and it's a + private room. Can specify another user (rather than the admin user) to be + granted power, e.g.: + + POST/_synapse/admin/v1/rooms/<room_id_or_alias>/make_room_admin + { + "user_id": "@foo:example.com" + } + """ + + PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.auth = hs.get_auth() + self.room_member_handler = hs.get_room_member_handler() + self.event_creation_handler = hs.get_event_creation_handler() + self.state_handler = hs.get_state_handler() + self.is_mine_id = hs.is_mine_id + + async def on_POST(self, request, room_identifier): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + content = parse_json_object_from_request(request, allow_empty_body=True) + + # Resolve to a room ID, if necessary. + if RoomID.is_valid(room_identifier): + room_id = room_identifier + elif RoomAlias.is_valid(room_identifier): + room_alias = RoomAlias.from_string(room_identifier) + room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias) + room_id = room_id.to_string() + else: + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) + + # Which user to grant room admin rights to. + user_to_add = content.get("user_id", requester.user.to_string()) + + # Figure out which local users currently have power in the room, if any. + room_state = await self.state_handler.get_current_state(room_id) + if not room_state: + raise SynapseError(400, "Server not in room") + + create_event = room_state[(EventTypes.Create, "")] + power_levels = room_state.get((EventTypes.PowerLevels, "")) + + if power_levels is not None: + # We pick the local user with the highest power. + user_power = power_levels.content.get("users", {}) + admin_users = [ + user_id for user_id in user_power if self.is_mine_id(user_id) + ] + admin_users.sort(key=lambda user: user_power[user]) + + if not admin_users: + raise SynapseError(400, "No local admin user in room") + + admin_user_id = admin_users[-1] + + pl_content = power_levels.content + else: + # If there is no power level events then the creator has rights. + pl_content = {} + admin_user_id = create_event.sender + if not self.is_mine_id(admin_user_id): + raise SynapseError( + 400, "No local admin user in room", + ) + + # Grant the user power equal to the room admin by attempting to send an + # updated power level event. 
+ new_pl_content = dict(pl_content) + new_pl_content["users"] = dict(pl_content.get("users", {})) + new_pl_content["users"][user_to_add] = new_pl_content["users"][admin_user_id] + + fake_requester = create_requester( + admin_user_id, authenticated_entity=requester.authenticated_entity, + ) + + try: + await self.event_creation_handler.create_and_send_nonmember_event( + fake_requester, + event_dict={ + "content": new_pl_content, + "sender": admin_user_id, + "type": EventTypes.PowerLevels, + "state_key": "", + "room_id": room_id, + }, + ) + except AuthError: + # The admin user we found turned out not to have enough power. + raise SynapseError( + 400, "No local admin user in room with power to update power levels." + ) + + # Now we check if the user we're granting admin rights to is already in + # the room. If not and it's not a public room we invite them. + member_event = room_state.get((EventTypes.Member, user_to_add)) + is_joined = False + if member_event: + is_joined = member_event.content["membership"] in ( + Membership.JOIN, + Membership.INVITE, + ) + + if is_joined: + return 200, {} + + join_rules = room_state.get((EventTypes.JoinRules, "")) + is_public = False + if join_rules: + is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC + + if is_public: + return 200, {} + + await self.room_member_handler.update_membership( + fake_requester, + target=UserID.from_string(user_to_add), + room_id=room_id, + action=Membership.INVITE, + ) + + return 200, {} diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py new file mode 100644
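The two room-admin changes above (the optional force_purge flag on room deletion, and the new make_room_admin endpoint) can be exercised roughly as follows; the URL, token and room/user IDs are placeholders.

    import requests

    SERVER = "https://synapse.example.com"
    HEADERS = {"Authorization": "Bearer <admin access token>"}

    # Shut down and purge a room; force_purge lets the purge go ahead even if
    # local users are still joined.
    requests.post(
        f"{SERVER}/_synapse/admin/v1/rooms/!room:example.com/delete",
        headers=HEADERS,
        json={"new_room_user_id": "@admin:example.com", "purge": True, "force_purge": True},
    )

    # Grant room-admin power levels to a user (defaults to the requesting admin).
    requests.post(
        f"{SERVER}/_synapse/admin/v1/rooms/!room:example.com/make_room_admin",
        headers=HEADERS,
        json={"user_id": "@foo:example.com"},
    )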
index 0000000000..f2490e382d
--- /dev/null
+++ b/synapse/rest/admin/statistics.py
@@ -0,0 +1,122 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING, Tuple + +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import RestServlet, parse_integer, parse_string +from synapse.http.site import SynapseRequest +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin +from synapse.storage.databases.main.stats import UserSortOrder +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class UserMediaStatisticsRestServlet(RestServlet): + """ + Get statistics about uploaded media by users. + """ + + PATTERNS = admin_patterns("/statistics/users/media$") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self.auth, request) + + order_by = parse_string( + request, "order_by", default=UserSortOrder.USER_ID.value + ) + if order_by not in ( + UserSortOrder.MEDIA_LENGTH.value, + UserSortOrder.MEDIA_COUNT.value, + UserSortOrder.USER_ID.value, + UserSortOrder.DISPLAYNAME.value, + ): + raise SynapseError( + 400, + "Unknown value for order_by: %s" % (order_by,), + errcode=Codes.INVALID_PARAM, + ) + + start = parse_integer(request, "from", default=0) + if start < 0: + raise SynapseError( + 400, + "Query parameter from must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + limit = parse_integer(request, "limit", default=100) + if limit < 0: + raise SynapseError( + 400, + "Query parameter limit must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + from_ts = parse_integer(request, "from_ts", default=0) + if from_ts < 0: + raise SynapseError( + 400, + "Query parameter from_ts must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + until_ts = parse_integer(request, "until_ts") + if until_ts is not None: + if until_ts < 0: + raise SynapseError( + 400, + "Query parameter until_ts must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + if until_ts <= from_ts: + raise SynapseError( + 400, + "Query parameter until_ts must be greater than from_ts.", + errcode=Codes.INVALID_PARAM, + ) + + search_term = parse_string(request, "search_term") + if search_term == "": + raise SynapseError( + 400, + "Query parameter search_term cannot be an empty string.", + errcode=Codes.INVALID_PARAM, + ) + + direction = parse_string(request, "dir", default="f") + if direction not in ("f", "b"): + raise SynapseError( + 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM + ) + + users_media, total = await self.store.get_users_media_usage_paginate( + start, limit, from_ts, until_ts, order_by, direction, search_term + ) + ret = {"users": users_media, "total": 
total} + if (start + limit) < total: + ret["next_token"] = start + len(users_media) + + return 200, ret diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
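The statistics endpoint paginates via from/limit and returns a next_token while more rows remain. A sketch of paging through it ordered by total media size, assuming the order_by value mirrors UserSortOrder.MEDIA_LENGTH ("media_length"); URL and token are placeholders:

    import requests

    SERVER = "https://synapse.example.com"
    HEADERS = {"Authorization": "Bearer <admin access token>"}

    params = {"order_by": "media_length", "dir": "b", "from": 0, "limit": 100}
    while True:
        resp = requests.get(
            f"{SERVER}/_synapse/admin/v1/statistics/users/media",
            headers=HEADERS,
            params=params,
        )
        body = resp.json()
        for row in body["users"]:
            print(row)
        if "next_token" not in body:
            break
        params["from"] = body["next_token"]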
index 20dc1d0e05..6658c2da56 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -16,6 +16,7 @@ import hashlib import hmac import logging from http import HTTPStatus +from typing import TYPE_CHECKING, Tuple from synapse.api.constants import UserTypes from synapse.api.errors import Codes, NotFoundError, SynapseError @@ -27,25 +28,29 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.rest.admin._base import ( admin_patterns, assert_requester_is_admin, assert_user_is_admin, - historical_admin_path_patterns, ) -from synapse.types import UserID +from synapse.rest.client.v2_alpha._base import client_patterns +from synapse.types import JsonDict, UserID + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) class UsersRestServlet(RestServlet): - PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)$") + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$") def __init__(self, hs): self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() - self.admin_handler = hs.get_handlers().admin_handler + self.admin_handler = hs.get_admin_handler() async def on_GET(self, request, user_id): target_user = UserID.from_string(user_id) @@ -82,7 +87,7 @@ class UsersRestServletV2(RestServlet): self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() - self.admin_handler = hs.get_handlers().admin_handler + self.admin_handler = hs.get_admin_handler() async def on_GET(self, request): await assert_requester_is_admin(self.auth, request) @@ -135,7 +140,7 @@ class UserRestServletV2(RestServlet): def __init__(self, hs): self.hs = hs self.auth = hs.get_auth() - self.admin_handler = hs.get_handlers().admin_handler + self.admin_handler = hs.get_admin_handler() self.store = hs.get_datastore() self.auth_handler = hs.get_auth_handler() self.profile_handler = hs.get_profile_handler() @@ -304,9 +309,9 @@ class UserRestServletV2(RestServlet): data={}, ) - if "avatar_url" in body and type(body["avatar_url"]) == str: + if "avatar_url" in body and isinstance(body["avatar_url"], str): await self.profile_handler.set_avatar_url( - user_id, requester, body["avatar_url"], True + target_user, requester, body["avatar_url"], True ) ret = await self.admin_handler.get_user(target_user) @@ -322,7 +327,7 @@ class UserRegisterServlet(RestServlet): nonce to the time it was generated, in int seconds. 
""" - PATTERNS = historical_admin_path_patterns("/register") + PATTERNS = admin_patterns("/register") NONCE_TIMEOUT = 60 def __init__(self, hs): @@ -399,10 +404,14 @@ class UserRegisterServlet(RestServlet): admin = body.get("admin", None) user_type = body.get("user_type", None) + displayname = body.get("displayname", None) if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: raise SynapseError(400, "Invalid user type") + if "mac" not in body: + raise SynapseError(400, "mac must be specified", errcode=Codes.BAD_JSON) + got_mac = body["mac"] want_mac_builder = hmac.new( @@ -435,6 +444,7 @@ class UserRegisterServlet(RestServlet): password_hash=password_hash, admin=bool(admin), user_type=user_type, + default_display_name=displayname, by_admin=True, ) @@ -443,12 +453,19 @@ class UserRegisterServlet(RestServlet): class WhoisRestServlet(RestServlet): - PATTERNS = historical_admin_path_patterns("/whois/(?P<user_id>[^/]*)") + path_regex = "/whois/(?P<user_id>[^/]*)$" + PATTERNS = ( + admin_patterns(path_regex) + + + # URL for spec reason + # https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-admin-whois-userid + client_patterns("/admin" + path_regex, v1=True) + ) def __init__(self, hs): self.hs = hs self.auth = hs.get_auth() - self.handlers = hs.get_handlers() + self.admin_handler = hs.get_admin_handler() async def on_GET(self, request, user_id): target_user = UserID.from_string(user_id) @@ -461,13 +478,13 @@ class WhoisRestServlet(RestServlet): if not self.hs.is_mine(target_user): raise SynapseError(400, "Can only whois a local user") - ret = await self.handlers.admin_handler.get_whois(target_user) + ret = await self.admin_handler.get_whois(target_user) return 200, ret class DeactivateAccountRestServlet(RestServlet): - PATTERNS = historical_admin_path_patterns("/deactivate/(?P<target_user_id>[^/]*)") + PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)") def __init__(self, hs): self._deactivate_account_handler = hs.get_deactivate_account_handler() @@ -498,7 +515,7 @@ class DeactivateAccountRestServlet(RestServlet): class AccountValidityRenewServlet(RestServlet): - PATTERNS = historical_admin_path_patterns("/account_validity/validity$") + PATTERNS = admin_patterns("/account_validity/validity$") def __init__(self, hs): """ @@ -541,9 +558,7 @@ class ResetPasswordRestServlet(RestServlet): 200 OK with empty object if success otherwise an error. """ - PATTERNS = historical_admin_path_patterns( - "/reset_password/(?P<target_user_id>[^/]*)" - ) + PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)") def __init__(self, hs): self.store = hs.get_datastore() @@ -585,13 +600,12 @@ class SearchUsersRestServlet(RestServlet): 200 OK with json object {list[dict[str, Any]], count} or empty object. 
""" - PATTERNS = historical_admin_path_patterns("/search_users/(?P<target_user_id>[^/]*)") + PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)") def __init__(self, hs): self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() - self.handlers = hs.get_handlers() async def on_GET(self, request, target_user_id): """Get request to search user table for specific users according to @@ -612,7 +626,7 @@ class SearchUsersRestServlet(RestServlet): term = parse_string(request, "term", required=True) logger.info("term: %s ", term) - ret = await self.handlers.store.search_users(term) + ret = await self.store.search_users(term) return 200, ret @@ -703,9 +717,160 @@ class UserMembershipRestServlet(RestServlet): if not self.is_mine(UserID.from_string(user_id)): raise SynapseError(400, "Can only lookup local users") + user = await self.store.get_user_by_id(user_id) + if user is None: + raise NotFoundError("Unknown user") + room_ids = await self.store.get_rooms_for_user(user_id) - if not room_ids: + ret = {"joined_rooms": list(room_ids), "total": len(room_ids)} + return 200, ret + + +class PushersRestServlet(RestServlet): + """ + Gets information about all pushers for a specific `user_id`. + + Example: + http://localhost:8008/_synapse/admin/v1/users/ + @user:server/pushers + + Returns: + pushers: Dictionary containing pushers information. + total: Number of pushers in dictonary `pushers`. + """ + + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$") + + def __init__(self, hs): + self.is_mine = hs.is_mine + self.store = hs.get_datastore() + self.auth = hs.get_auth() + + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self.auth, request) + + if not self.is_mine(UserID.from_string(user_id)): + raise SynapseError(400, "Can only lookup local users") + + if not await self.store.get_user_by_id(user_id): raise NotFoundError("User not found") - ret = {"joined_rooms": list(room_ids), "total": len(room_ids)} + pushers = await self.store.get_pushers_by_user_id(user_id) + + filtered_pushers = [p.as_dict() for p in pushers] + + return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)} + + +class UserMediaRestServlet(RestServlet): + """ + Gets information about all uploaded local media for a specific `user_id`. + + Example: + http://localhost:8008/_synapse/admin/v1/users/ + @user:server/media + + Args: + The parameters `from` and `limit` are required for pagination. + By default, a `limit` of 100 is used. 
+ Returns: + A list of media and an integer representing the total number of + media that exist given for this user + """ + + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$") + + def __init__(self, hs): + self.is_mine = hs.is_mine + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self.auth, request) + + if not self.is_mine(UserID.from_string(user_id)): + raise SynapseError(400, "Can only lookup local users") + + user = await self.store.get_user_by_id(user_id) + if user is None: + raise NotFoundError("Unknown user") + + start = parse_integer(request, "from", default=0) + limit = parse_integer(request, "limit", default=100) + + if start < 0: + raise SynapseError( + 400, + "Query parameter from must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if limit < 0: + raise SynapseError( + 400, + "Query parameter limit must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + media, total = await self.store.get_local_media_by_user_paginate( + start, limit, user_id + ) + + ret = {"media": media, "total": total} + if (start + limit) < total: + ret["next_token"] = start + len(media) + return 200, ret + + +class UserTokenRestServlet(RestServlet): + """An admin API for logging in as a user. + + Example: + + POST /_synapse/admin/v1/users/@test:example.com/login + {} + + 200 OK + { + "access_token": "<some_token>" + } + """ + + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/login$") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() + + async def on_POST(self, request, user_id): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + auth_user = requester.user + + if not self.hs.is_mine_id(user_id): + raise SynapseError(400, "Only local users can be logged in as") + + body = parse_json_object_from_request(request, allow_empty_body=True) + + valid_until_ms = body.get("valid_until_ms") + if valid_until_ms and not isinstance(valid_until_ms, int): + raise SynapseError(400, "'valid_until_ms' parameter must be an int") + + if auth_user.to_string() == user_id: + raise SynapseError(400, "Cannot use admin API to login as self") + + token = await self.auth_handler.get_access_token_for_user_id( + user_id=auth_user.to_string(), + device_id=None, + valid_until_ms=valid_until_ms, + puppets_user_id=user_id, + ) + + return 200, {"access_token": token} diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
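Of the new user admin servlets, the login-as-user endpoint is the most involved: it mints a puppet access token for a local user, optionally time-limited via valid_until_ms. A usage sketch (placeholder URL and admin token):

    import requests

    SERVER = "https://synapse.example.com"

    resp = requests.post(
        f"{SERVER}/_synapse/admin/v1/users/@alice:example.com/login",
        headers={"Authorization": "Bearer <admin access token>"},
        json={},  # optionally {"valid_until_ms": <expiry in ms since the epoch>}
    )
    puppet_token = resp.json()["access_token"]

    # The token can then be used as @alice for ordinary client-server calls:
    whoami = requests.get(
        f"{SERVER}/_matrix/client/r0/account/whoami",
        headers={"Authorization": f"Bearer {puppet_token}"},
    )
    print(whoami.json())  # {"user_id": "@alice:example.com"}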
index faabeeb91c..e5af26b176 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -42,14 +42,13 @@ class ClientDirectoryServer(RestServlet): def __init__(self, hs): super().__init__() self.store = hs.get_datastore() - self.handlers = hs.get_handlers() + self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() async def on_GET(self, request, room_alias): room_alias = RoomAlias.from_string(room_alias) - dir_handler = self.handlers.directory_handler - res = await dir_handler.get_association(room_alias) + res = await self.directory_handler.get_association(room_alias) return 200, res @@ -79,19 +78,19 @@ class ClientDirectoryServer(RestServlet): requester = await self.auth.get_user_by_req(request) - await self.handlers.directory_handler.create_association( + await self.directory_handler.create_association( requester, room_alias, room_id, servers ) return 200, {} async def on_DELETE(self, request, room_alias): - dir_handler = self.handlers.directory_handler - try: service = self.auth.get_appservice_by_req(request) room_alias = RoomAlias.from_string(room_alias) - await dir_handler.delete_appservice_association(service, room_alias) + await self.directory_handler.delete_appservice_association( + service, room_alias + ) logger.info( "Application service at %s deleted alias %s", service.url, @@ -107,7 +106,7 @@ class ClientDirectoryServer(RestServlet): room_alias = RoomAlias.from_string(room_alias) - await dir_handler.delete_association(requester, room_alias) + await self.directory_handler.delete_association(requester, room_alias) logger.info( "User %s deleted alias %s", user.to_string(), room_alias.to_string() @@ -122,7 +121,7 @@ class ClientDirectoryListServer(RestServlet): def __init__(self, hs): super().__init__() self.store = hs.get_datastore() - self.handlers = hs.get_handlers() + self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() async def on_GET(self, request, room_id): @@ -138,7 +137,7 @@ class ClientDirectoryListServer(RestServlet): content = parse_json_object_from_request(request) visibility = content.get("visibility", "public") - await self.handlers.directory_handler.edit_published_room_list( + await self.directory_handler.edit_published_room_list( requester, room_id, visibility ) @@ -147,7 +146,7 @@ class ClientDirectoryListServer(RestServlet): async def on_DELETE(self, request, room_id): requester = await self.auth.get_user_by_req(request) - await self.handlers.directory_handler.edit_published_room_list( + await self.directory_handler.edit_published_room_list( requester, room_id, "private" ) @@ -162,7 +161,7 @@ class ClientAppserviceDirectoryListServer(RestServlet): def __init__(self, hs): super().__init__() self.store = hs.get_datastore() - self.handlers = hs.get_handlers() + self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() def on_PUT(self, request, network_id, room_id): @@ -180,7 +179,7 @@ class ClientAppserviceDirectoryListServer(RestServlet): 403, "Only appservices can edit the appservice published room list" ) - await self.handlers.directory_handler.edit_published_appservice_room_list( + await self.directory_handler.edit_published_appservice_room_list( requester.app_service.id, network_id, room_id, visibility ) diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index 1ecb77aa26..6de4078290 100644
--- a/synapse/rest/client/v1/events.py
+++ b/synapse/rest/client/v1/events.py
@@ -67,9 +67,6 @@ class EventStreamRestServlet(RestServlet): return 200, chunk - def on_OPTIONS(self, request): - return 200, {} - class EventRestServlet(RestServlet): PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 3d1693d7ac..5f4c6703db 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,15 +14,11 @@ # limitations under the License. import logging -from typing import Awaitable, Callable, Dict, Optional +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter from synapse.appservice import ApplicationService -from synapse.handlers.auth import ( - convert_client_dict_legacy_fields_to_identifier, - login_id_phone_to_thirdparty, -) from synapse.http.server import finish_request from synapse.http.servlet import ( RestServlet, @@ -33,7 +29,9 @@ from synapse.http.site import SynapseRequest from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.well_known import WellKnownBuilder from synapse.types import JsonDict, UserID -from synapse.util.threepids import canonicalise_email + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -47,7 +45,7 @@ class LoginRestServlet(RestServlet): JWT_TYPE_DEPRECATED = "m.login.jwt" APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs @@ -67,7 +65,6 @@ class LoginRestServlet(RestServlet): self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() - self.handlers = hs.get_handlers() self._well_known_builder = WellKnownBuilder(hs) self._address_ratelimiter = Ratelimiter( clock=hs.get_clock(), @@ -79,11 +76,6 @@ class LoginRestServlet(RestServlet): rate_hz=self.hs.config.rc_login_account.per_second, burst_count=self.hs.config.rc_login_account.burst_count, ) - self._failed_attempts_ratelimiter = Ratelimiter( - clock=hs.get_clock(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - ) def on_GET(self, request: SynapseRequest): flows = [] @@ -111,28 +103,32 @@ class LoginRestServlet(RestServlet): ({"type": t} for t in self.auth_handler.get_supported_login_types()) ) - return 200, {"flows": flows} + flows.append({"type": LoginRestServlet.APPSERVICE_TYPE}) - def on_OPTIONS(self, request: SynapseRequest): - return 200, {} + return 200, {"flows": flows} async def on_POST(self, request: SynapseRequest): - self._address_ratelimiter.ratelimit(request.getClientIP()) - login_submission = parse_json_object_from_request(request) try: if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE: appservice = self.auth.get_appservice_by_req(request) + + if appservice.is_rate_limited(): + self._address_ratelimiter.ratelimit(request.getClientIP()) + result = await self._do_appservice_login(login_submission, appservice) elif self.jwt_enabled and ( login_submission["type"] == LoginRestServlet.JWT_TYPE or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED ): + self._address_ratelimiter.ratelimit(request.getClientIP()) result = await self._do_jwt_login(login_submission) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: + self._address_ratelimiter.ratelimit(request.getClientIP()) result = await self._do_token_login(login_submission) else: + self._address_ratelimiter.ratelimit(request.getClientIP()) result = await self._do_other_login(login_submission) except KeyError: raise SynapseError(400, "Missing JSON keys.") @@ -142,32 +138,38 @@ class LoginRestServlet(RestServlet): result["well_known"] = well_known_data return 200, result - def _get_qualified_user_id(self, identifier): - if identifier["type"] != 
"m.id.user": - raise SynapseError(400, "Unknown login identifier type") - if "user" not in identifier: - raise SynapseError(400, "User identifier is missing 'user' key") - - if identifier["user"].startswith("@"): - return identifier["user"] - else: - return UserID(identifier["user"], self.hs.hostname).to_string() - async def _do_appservice_login( self, login_submission: JsonDict, appservice: ApplicationService ): - logger.info( - "Got appservice login request with identifier: %r", - login_submission.get("identifier"), - ) + identifier = login_submission.get("identifier") + logger.info("Got appservice login request with identifier: %r", identifier) + + if not isinstance(identifier, dict): + raise SynapseError( + 400, "Invalid identifier in login submission", Codes.INVALID_PARAM + ) + + # this login flow only supports identifiers of type "m.id.user". + if identifier.get("type") != "m.id.user": + raise SynapseError( + 400, "Unknown login identifier type", Codes.INVALID_PARAM + ) - identifier = convert_client_dict_legacy_fields_to_identifier(login_submission) - qualified_user_id = self._get_qualified_user_id(identifier) + user = identifier.get("user") + if not isinstance(user, str): + raise SynapseError(400, "Invalid user in identifier", Codes.INVALID_PARAM) + + if user.startswith("@"): + qualified_user_id = user + else: + qualified_user_id = UserID(user, self.hs.hostname).to_string() if not appservice.is_interested_in_user(qualified_user_id): raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN) - return await self._complete_login(qualified_user_id, login_submission) + return await self._complete_login( + qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited() + ) async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]: """Handle non-token/saml/jwt logins @@ -188,91 +190,9 @@ class LoginRestServlet(RestServlet): login_submission.get("address"), login_submission.get("user"), ) - identifier = convert_client_dict_legacy_fields_to_identifier(login_submission) - - # convert phone type identifiers to generic threepids - if identifier["type"] == "m.id.phone": - identifier = login_id_phone_to_thirdparty(identifier) - - # convert threepid identifiers to user IDs - if identifier["type"] == "m.id.thirdparty": - address = identifier.get("address") - medium = identifier.get("medium") - - if medium is None or address is None: - raise SynapseError(400, "Invalid thirdparty identifier") - - # For emails, canonicalise the address. - # We store all email addresses canonicalised in the DB. - # (See add_threepid in synapse/handlers/auth.py) - if medium == "email": - try: - address = canonicalise_email(address) - except ValueError as e: - raise SynapseError(400, str(e)) - - # We also apply account rate limiting using the 3PID as a key, as - # otherwise using 3PID bypasses the ratelimiting based on user ID. 
- self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False) - - # Check for login providers that support 3pid login types - ( - canonical_user_id, - callback_3pid, - ) = await self.auth_handler.check_password_provider_3pid( - medium, address, login_submission["password"] - ) - if canonical_user_id: - # Authentication through password provider and 3pid succeeded - - result = await self._complete_login( - canonical_user_id, login_submission, callback_3pid - ) - return result - - # No password providers were able to handle this 3pid - # Check local store - user_id = await self.hs.get_datastore().get_user_id_by_threepid( - medium, address - ) - if not user_id: - logger.warning( - "unknown 3pid identifier medium %s, address %r", medium, address - ) - # We mark that we've failed to log in here, as - # `check_password_provider_3pid` might have returned `None` due - # to an incorrect password, rather than the account not - # existing. - # - # If it returned None but the 3PID was bound then we won't hit - # this code path, which is fine as then the per-user ratelimit - # will kick in below. - self._failed_attempts_ratelimiter.can_do_action((medium, address)) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) - - identifier = {"type": "m.id.user", "user": user_id} - - # by this point, the identifier should be an m.id.user: if it's anything - # else, we haven't understood it. - qualified_user_id = self._get_qualified_user_id(identifier) - - # Check if we've hit the failed ratelimit (but don't update it) - self._failed_attempts_ratelimiter.ratelimit( - qualified_user_id.lower(), update=False + canonical_user_id, callback = await self.auth_handler.validate_login( + login_submission, ratelimit=True ) - - try: - canonical_user_id, callback = await self.auth_handler.validate_login( - identifier["user"], login_submission - ) - except LoginError: - # The user has failed to log in, so we need to update the rate - # limiter. Using `can_do_action` avoids us raising a ratelimit - # exception and masking the LoginError. The actual ratelimiting - # should have happened above. - self._failed_attempts_ratelimiter.can_do_action(qualified_user_id.lower()) - raise - result = await self._complete_login( canonical_user_id, login_submission, callback ) @@ -284,6 +204,7 @@ class LoginRestServlet(RestServlet): login_submission: JsonDict, callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None, create_non_existent_users: bool = False, + ratelimit: bool = True, ) -> Dict[str, str]: """Called when we've successfully authed the user and now need to actually login them in (e.g. create devices). This gets called on @@ -298,6 +219,7 @@ class LoginRestServlet(RestServlet): callback: Callback function to run after login. create_non_existent_users: Whether to create the user if they don't exist. Defaults to False. + ratelimit: Whether to ratelimit the login request. Returns: result: Dictionary of account information after successful login. @@ -306,7 +228,8 @@ class LoginRestServlet(RestServlet): # Before we actually log them in we check if they've already logged in # too often. This happens here rather than before as we don't # necessarily know the user before now. - self._account_ratelimiter.ratelimit(user_id.lower()) + if ratelimit: + self._account_ratelimiter.ratelimit(user_id.lower()) if create_non_existent_users: canonical_uid = await self.auth_handler.check_user_exists(user_id) diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
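After this refactor the appservice login path validates the identifier itself and accepts only "m.id.user". A sketch of the request an application service would send, using the login type advertised above; the homeserver URL and as_token are placeholders:

    import requests

    resp = requests.post(
        "https://synapse.example.com/_matrix/client/r0/login",
        headers={"Authorization": "Bearer <as_token>"},
        json={
            "type": "uk.half-shot.msc2778.login.application_service",
            "identifier": {"type": "m.id.user", "user": "bridged_alice"},
        },
    )
    print(resp.json().get("access_token"))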
index f792b50cdc..ad8cea49c6 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -30,9 +30,6 @@ class LogoutRestServlet(RestServlet): self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() - def on_OPTIONS(self, request): - return 200, {} - async def on_POST(self, request): requester = await self.auth.get_user_by_req(request, allow_expired=True) @@ -58,9 +55,6 @@ class LogoutAllRestServlet(RestServlet): self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() - def on_OPTIONS(self, request): - return 200, {} - async def on_POST(self, request): requester = await self.auth.get_user_by_req(request, allow_expired=True) user_id = requester.user.to_string() diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index 79d8e3057f..23a529f8e3 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -86,9 +86,6 @@ class PresenceStatusRestServlet(RestServlet): return 200, {} - def on_OPTIONS(self, request): - return 200, {} - def register_servlets(hs, http_server): PresenceStatusRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index b686cd671f..85a66458c5 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -59,15 +59,14 @@ class ProfileDisplaynameRestServlet(RestServlet): try: new_name = content["displayname"] except Exception: - return 400, "Unable to parse name" + raise SynapseError( + code=400, msg="Unable to parse name", errcode=Codes.BAD_JSON, + ) await self.profile_handler.set_displayname(user, requester, new_name, is_admin) return 200, {} - def on_OPTIONS(self, request, user_id): - return 200, {} - class ProfileAvatarURLRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True) @@ -116,9 +115,6 @@ class ProfileAvatarURLRestServlet(RestServlet): return 200, {} - def on_OPTIONS(self, request, user_id): - return 200, {} - class ProfileRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True) diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index f9eecb7cf5..241e535917 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -155,9 +155,6 @@ class PushRuleRestServlet(RestServlet): else: raise UnrecognizedRequestError() - def on_OPTIONS(self, request, path): - return 200, {} - def notify_user(self, user_id): stream_id = self.store.get_max_push_rules_stream_id() self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 28dabf1c7a..89823fcc39 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -28,17 +28,6 @@ from synapse.rest.client.v2_alpha._base import client_patterns logger = logging.getLogger(__name__) -ALLOWED_KEYS = { - "app_display_name", - "app_id", - "data", - "device_display_name", - "kind", - "lang", - "profile_tag", - "pushkey", -} - class PushersRestServlet(RestServlet): PATTERNS = client_patterns("/pushers$", v1=True) @@ -54,15 +43,10 @@ class PushersRestServlet(RestServlet): pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) - filtered_pushers = [ - {k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers - ] + filtered_pushers = [p.as_dict() for p in pushers] return 200, {"pushers": filtered_pushers} - def on_OPTIONS(self, _): - return 200, {} - class PushersSetRestServlet(RestServlet): PATTERNS = client_patterns("/pushers/set$", v1=True) @@ -140,9 +124,6 @@ class PushersSetRestServlet(RestServlet): return 200, {} - def on_OPTIONS(self, _): - return 200, {} - class PushersRemoveRestServlet(RestServlet): """ @@ -182,9 +163,6 @@ class PushersRemoveRestServlet(RestServlet): ) return None - def on_OPTIONS(self, _): - return 200, {} - def register_servlets(hs, http_server): PushersRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index b63389e5fe..5647e8c577 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -18,7 +18,7 @@ import logging import re -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional from urllib import parse as urlparse from synapse.api.constants import EventTypes, Membership @@ -48,8 +48,7 @@ from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, from synapse.util import json_decoder from synapse.util.stringutils import random_string -MYPY = False -if MYPY: +if TYPE_CHECKING: import synapse.server logger = logging.getLogger(__name__) @@ -72,20 +71,6 @@ class RoomCreateRestServlet(TransactionRestServlet): def register(self, http_server): PATTERNS = "/createRoom" register_txn_path(self, PATTERNS, http_server) - # define CORS for all of /rooms in RoomCreateRestServlet for simplicity - http_server.register_paths( - "OPTIONS", - client_patterns("/rooms(?:/.*)?$", v1=True), - self.on_OPTIONS, - self.__class__.__name__, - ) - # define CORS for /createRoom[/txnid] - http_server.register_paths( - "OPTIONS", - client_patterns("/createRoom(?:/.*)?$", v1=True), - self.on_OPTIONS, - self.__class__.__name__, - ) def on_PUT(self, request, txn_id): set_tag("txn_id", txn_id) @@ -104,15 +89,11 @@ class RoomCreateRestServlet(TransactionRestServlet): user_supplied_config = parse_json_object_from_request(request) return user_supplied_config - def on_OPTIONS(self, request): - return 200, {} - # TODO: Needs unit testing for generic events class RoomStateEventRestServlet(TransactionRestServlet): def __init__(self, hs): super().__init__(hs) - self.handlers = hs.get_handlers() self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() self.message_handler = hs.get_message_handler() @@ -798,7 +779,6 @@ class RoomMembershipRestServlet(TransactionRestServlet): class RoomRedactEventRestServlet(TransactionRestServlet): def __init__(self, hs): super().__init__(hs) - self.handlers = hs.get_handlers() self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() @@ -903,7 +883,7 @@ class RoomAliasListServlet(RestServlet): def __init__(self, hs: "synapse.server.HomeServer"): super().__init__() self.auth = hs.get_auth() - self.directory_handler = hs.get_handlers().directory_handler + self.directory_handler = hs.get_directory_handler() async def on_GET(self, request, room_id): requester = await self.auth.get_user_by_req(request) @@ -920,7 +900,7 @@ class SearchRestServlet(RestServlet): def __init__(self, hs): super().__init__() - self.handlers = hs.get_handlers() + self.search_handler = hs.get_search_handler() self.auth = hs.get_auth() async def on_POST(self, request): @@ -929,9 +909,7 @@ class SearchRestServlet(RestServlet): content = parse_json_object_from_request(request) batch = parse_string(request, "next_batch") - results = await self.handlers.search_handler.search( - requester.user, content, batch - ) + results = await self.search_handler.search(requester.user, content, batch) return 200, results @@ -985,25 +963,28 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False): ) -def register_servlets(hs, http_server): +def register_servlets(hs, http_server, is_worker=False): RoomStateEventRestServlet(hs).register(http_server) - RoomCreateRestServlet(hs).register(http_server) RoomMemberListRestServlet(hs).register(http_server) JoinedRoomMemberListRestServlet(hs).register(http_server) RoomMessageListRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) - RoomForgetRestServlet(hs).register(http_server) 
RoomMembershipRestServlet(hs).register(http_server) RoomSendEventRestServlet(hs).register(http_server) PublicRoomListRestServlet(hs).register(http_server) RoomStateRestServlet(hs).register(http_server) RoomRedactEventRestServlet(hs).register(http_server) RoomTypingRestServlet(hs).register(http_server) - SearchRestServlet(hs).register(http_server) - JoinedRoomsRestServlet(hs).register(http_server) - RoomEventServlet(hs).register(http_server) RoomEventContextServlet(hs).register(http_server) - RoomAliasListServlet(hs).register(http_server) + + # Some servlets only get registered for the main process. + if not is_worker: + RoomCreateRestServlet(hs).register(http_server) + RoomForgetRestServlet(hs).register(http_server) + SearchRestServlet(hs).register(http_server) + JoinedRoomsRestServlet(hs).register(http_server) + RoomEventServlet(hs).register(http_server) + RoomAliasListServlet(hs).register(http_server) def register_deprecated_servlets(hs, http_server): diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
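A minimal sketch of how the new is_worker parameter above is intended to be used by whatever wires up the client REST resources; the wire_up_room_servlets helper and running_as_worker flag below are hypothetical names, not part of this change.

from synapse.rest.client.v1 import room

def wire_up_room_servlets(hs, http_server, running_as_worker: bool) -> None:
    # On a worker, servlets such as RoomCreateRestServlet, SearchRestServlet
    # and RoomAliasListServlet are skipped; they stay on the main process.
    room.register_servlets(hs, http_server, is_worker=running_as_worker)
    room.register_deprecated_servlets(hs, http_server)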
index b8d491ca5c..d07ca2c47c 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -69,9 +69,6 @@ class VoipRestServlet(RestServlet): }, ) - def on_OPTIONS(self, request): - return 200, {} - def register_servlets(hs, http_server): VoipRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index ab5815e7f7..d837bde1d6 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -38,6 +38,7 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) +from synapse.metrics import threepid_send_requests from synapse.push.mailer import Mailer from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.stringutils import assert_valid_client_secret, random_string @@ -56,7 +57,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): self.hs = hs self.datastore = hs.get_datastore() self.config = hs.config - self.identity_handler = hs.get_handlers().identity_handler + self.identity_handler = hs.get_identity_handler() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: self.mailer = Mailer( @@ -114,7 +115,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it # look like we did something. - await self.hs.clock.sleep(random.randint(1, 10) / 10) + await self.hs.get_clock().sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) @@ -143,6 +144,10 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # Wrap the session id in a JSON object ret = {"sid": sid} + threepid_send_requests.labels(type="email", reason="password_reset").observe( + send_attempt + ) + return 200, ret @@ -249,14 +254,18 @@ class PasswordRestServlet(RestServlet): logger.error("Auth succeeded but no known type! %r", result.keys()) raise SynapseError(500, "", Codes.UNKNOWN) - # If we have a password in this request, prefer it. Otherwise, there - # must be a password hash from an earlier request. + # If we have a password in this request, prefer it. Otherwise, use the + # password hash from an earlier request. if new_password: password_hash = await self.auth_handler.hash(new_password) - else: + elif session_id is not None: password_hash = await self.auth_handler.get_session_data( session_id, "password_hash", None ) + else: + # UI validation was skipped, but the request did not include a new + # password. + password_hash = None if not password_hash: raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) @@ -268,9 +277,6 @@ class PasswordRestServlet(RestServlet): return 200, {} - def on_OPTIONS(self, _): - return 200, {} - class DeactivateAccountRestServlet(RestServlet): PATTERNS = client_patterns("/account/deactivate$") @@ -327,7 +333,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): super().__init__() self.hs = hs self.config = hs.config - self.identity_handler = hs.get_handlers().identity_handler + self.identity_handler = hs.get_identity_handler() self.store = self.hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: @@ -385,7 +391,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it # look like we did something. 
- await self.hs.clock.sleep(random.randint(1, 10) / 10) + await self.hs.get_clock().sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) @@ -414,6 +420,10 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): # Wrap the session id in a JSON object ret = {"sid": sid} + threepid_send_requests.labels(type="email", reason="add_threepid").observe( + send_attempt + ) + return 200, ret @@ -424,7 +434,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): self.hs = hs super().__init__() self.store = self.hs.get_datastore() - self.identity_handler = hs.get_handlers().identity_handler + self.identity_handler = hs.get_identity_handler() async def on_POST(self, request): body = parse_json_object_from_request(request) @@ -460,7 +470,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it # look like we did something. - await self.hs.clock.sleep(random.randint(1, 10) / 10) + await self.hs.get_clock().sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) @@ -484,6 +494,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): next_link, ) + threepid_send_requests.labels(type="msisdn", reason="add_threepid").observe( + send_attempt + ) + return 200, ret @@ -574,7 +588,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): self.config = hs.config self.clock = hs.get_clock() self.store = hs.get_datastore() - self.identity_handler = hs.get_handlers().identity_handler + self.identity_handler = hs.get_identity_handler() async def on_POST(self, request): if not self.config.account_threepid_delegate_msisdn: @@ -604,7 +618,7 @@ class ThreepidRestServlet(RestServlet): def __init__(self, hs): super().__init__() self.hs = hs - self.identity_handler = hs.get_handlers().identity_handler + self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self.datastore = self.hs.get_datastore() @@ -660,7 +674,7 @@ class ThreepidAddRestServlet(RestServlet): def __init__(self, hs): super().__init__() self.hs = hs - self.identity_handler = hs.get_handlers().identity_handler + self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() @@ -711,7 +725,7 @@ class ThreepidBindRestServlet(RestServlet): def __init__(self, hs): super().__init__() self.hs = hs - self.identity_handler = hs.get_handlers().identity_handler + self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() async def on_POST(self, request): @@ -740,7 +754,7 @@ class ThreepidUnbindRestServlet(RestServlet): def __init__(self, hs): super().__init__() self.hs = hs - self.identity_handler = hs.get_handlers().identity_handler + self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() self.datastore = self.hs.get_datastore() diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
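The account.py hunks above start recording a threepid_send_requests observation each time a validation email or SMS is dispatched. A small sketch of that pattern, assuming threepid_send_requests is a Prometheus-style metric with type/reason labels; the helper name below is hypothetical, while the label values are the ones used in the hunks.

from synapse.metrics import threepid_send_requests

def record_threepid_send(medium: str, reason: str, send_attempt: int) -> None:
    # medium is "email" or "msisdn"; reason is e.g. "password_reset",
    # "add_threepid" or "register".
    threepid_send_requests.labels(type=medium, reason=reason).observe(send_attempt)

# e.g. after the third password-reset email for a session:
# record_threepid_send("email", "password_reset", 3)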
index 5fbfae5991..fab077747f 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -176,9 +176,6 @@ class AuthRestServlet(RestServlet): respond_with_html(request, 200, html) return None - def on_OPTIONS(self, _): - return 200, {} - def register_servlets(hs, http_server): AuthRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index 7e174de692..af117cb27c 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,6 +22,7 @@ from synapse.http.servlet import ( assert_params_in_dict, parse_json_object_from_request, ) +from synapse.http.site import SynapseRequest from ._base import client_patterns, interactive_auth_handler @@ -151,7 +153,139 @@ class DeviceRestServlet(RestServlet): return 200, {} +class DehydratedDeviceServlet(RestServlet): + """Retrieve or store a dehydrated device. + + GET /org.matrix.msc2697.v2/dehydrated_device + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "device_id": "dehydrated_device_id", + "device_data": { + "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm", + "account": "dehydrated_device" + } + } + + PUT /org.matrix.msc2697/dehydrated_device + Content-Type: application/json + + { + "device_data": { + "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm", + "account": "dehydrated_device" + } + } + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "device_id": "dehydrated_device_id" + } + + """ + + PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device", releases=()) + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + + async def on_GET(self, request: SynapseRequest): + requester = await self.auth.get_user_by_req(request) + dehydrated_device = await self.device_handler.get_dehydrated_device( + requester.user.to_string() + ) + if dehydrated_device is not None: + (device_id, device_data) = dehydrated_device + result = {"device_id": device_id, "device_data": device_data} + return (200, result) + else: + raise errors.NotFoundError("No dehydrated device available") + + async def on_PUT(self, request: SynapseRequest): + submission = parse_json_object_from_request(request) + requester = await self.auth.get_user_by_req(request) + + if "device_data" not in submission: + raise errors.SynapseError( + 400, "device_data missing", errcode=errors.Codes.MISSING_PARAM, + ) + elif not isinstance(submission["device_data"], dict): + raise errors.SynapseError( + 400, + "device_data must be an object", + errcode=errors.Codes.INVALID_PARAM, + ) + + device_id = await self.device_handler.store_dehydrated_device( + requester.user.to_string(), + submission["device_data"], + submission.get("initial_device_display_name", None), + ) + return 200, {"device_id": device_id} + + +class ClaimDehydratedDeviceServlet(RestServlet): + """Claim a dehydrated device. 
+ + POST /org.matrix.msc2697.v2/dehydrated_device/claim + Content-Type: application/json + + { + "device_id": "dehydrated_device_id" + } + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "success": true, + } + + """ + + PATTERNS = client_patterns( + "/org.matrix.msc2697.v2/dehydrated_device/claim", releases=() + ) + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + + async def on_POST(self, request: SynapseRequest): + requester = await self.auth.get_user_by_req(request) + + submission = parse_json_object_from_request(request) + + if "device_id" not in submission: + raise errors.SynapseError( + 400, "device_id missing", errcode=errors.Codes.MISSING_PARAM, + ) + elif not isinstance(submission["device_id"], str): + raise errors.SynapseError( + 400, "device_id must be a string", errcode=errors.Codes.INVALID_PARAM, + ) + + result = await self.device_handler.rehydrate_device( + requester.user.to_string(), + self.auth.get_access_token_from_request(request), + submission["device_id"], + ) + + return (200, result) + + def register_servlets(hs, http_server): DeleteDevicesRestServlet(hs).register(http_server) DevicesRestServlet(hs).register(http_server) DeviceRestServlet(hs).register(http_server) + DehydratedDeviceServlet(hs).register(http_server) + ClaimDehydratedDeviceServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
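The two servlets above implement the unstable MSC2697 dehydrated-device endpoints. A rough client-side sketch of the round trip they describe, assuming the usual /_matrix/client/unstable prefix; the homeserver URL and access token are placeholders, and the JSON bodies mirror the servlet docstrings.

import requests

BASE = "https://matrix.example.org/_matrix/client/unstable"  # assumed prefix
HEADERS = {"Authorization": "Bearer <access_token>"}          # placeholder

# Store (PUT) a dehydrated device.
stored = requests.put(
    f"{BASE}/org.matrix.msc2697.v2/dehydrated_device",
    headers=HEADERS,
    json={
        "device_data": {
            "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm",
            "account": "dehydrated_device",
        }
    },
).json()

# Later, a fresh login claims (rehydrates) it.
claimed = requests.post(
    f"{BASE}/org.matrix.msc2697.v2/dehydrated_device/claim",
    headers=HEADERS,
    json={"device_id": stored["device_id"]},
).json()
print(claimed)  # e.g. {'success': True}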
index a3bb095c2d..5b5da71815 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -15,6 +15,7 @@ # limitations under the License. import logging +from functools import wraps from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -25,6 +26,22 @@ from ._base import client_patterns logger = logging.getLogger(__name__) +def _validate_group_id(f): + """Wrapper to validate the form of the group ID. + + Can be applied to any on_FOO methods that accepts a group ID as a URL parameter. + """ + + @wraps(f) + def wrapper(self, request, group_id, *args, **kwargs): + if not GroupID.is_valid(group_id): + raise SynapseError(400, "%s is not a legal group ID" % (group_id,)) + + return f(self, request, group_id, *args, **kwargs) + + return wrapper + + class GroupServlet(RestServlet): """Get the group profile """ @@ -37,6 +54,7 @@ class GroupServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -47,6 +65,7 @@ class GroupServlet(RestServlet): return 200, group_description + @_validate_group_id async def on_POST(self, request, group_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -71,6 +90,7 @@ class GroupSummaryServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -102,6 +122,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id, category_id, room_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -117,6 +138,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): return 200, resp + @_validate_group_id async def on_DELETE(self, request, group_id, category_id, room_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -142,6 +164,7 @@ class GroupCategoryServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id, category_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -152,6 +175,7 @@ class GroupCategoryServlet(RestServlet): return 200, category + @_validate_group_id async def on_PUT(self, request, group_id, category_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -163,6 +187,7 @@ class GroupCategoryServlet(RestServlet): return 200, resp + @_validate_group_id async def on_DELETE(self, request, group_id, category_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -186,6 +211,7 @@ class GroupCategoriesServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -209,6 +235,7 @@ class GroupRoleServlet(RestServlet): self.clock = hs.get_clock() 
self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id, role_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -219,6 +246,7 @@ class GroupRoleServlet(RestServlet): return 200, category + @_validate_group_id async def on_PUT(self, request, group_id, role_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -230,6 +258,7 @@ class GroupRoleServlet(RestServlet): return 200, resp + @_validate_group_id async def on_DELETE(self, request, group_id, role_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -253,6 +282,7 @@ class GroupRolesServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -284,6 +314,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id, role_id, user_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -299,6 +330,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): return 200, resp + @_validate_group_id async def on_DELETE(self, request, group_id, role_id, user_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -322,13 +354,11 @@ class GroupRoomServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - if not GroupID.is_valid(group_id): - raise SynapseError(400, "%s was not legal group ID" % (group_id,)) - result = await self.groups_handler.get_rooms_in_group( group_id, requester_user_id ) @@ -348,6 +378,7 @@ class GroupUsersServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -371,6 +402,7 @@ class GroupInvitedUsersServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -393,6 +425,7 @@ class GroupSettingJoinPolicyServlet(RestServlet): self.auth = hs.get_auth() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -449,6 +482,7 @@ class GroupAdminRoomsServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id, room_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -460,6 +494,7 @@ class GroupAdminRoomsServlet(RestServlet): return 200, result + @_validate_group_id 
async def on_DELETE(self, request, group_id, room_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -486,6 +521,7 @@ class GroupAdminRoomsConfigServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id, room_id, config_key): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -514,6 +550,7 @@ class GroupAdminUsersInviteServlet(RestServlet): self.store = hs.get_datastore() self.is_mine_id = hs.is_mine_id + @_validate_group_id async def on_PUT(self, request, group_id, user_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -541,6 +578,7 @@ class GroupAdminUsersKickServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id, user_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -565,6 +603,7 @@ class GroupSelfLeaveServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -589,6 +628,7 @@ class GroupSelfJoinServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -613,6 +653,7 @@ class GroupSelfAcceptInviteServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -637,6 +678,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() + @_validate_group_id async def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
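The @_validate_group_id decorator introduced above replaces the old ad-hoc GroupID.is_valid check in GroupRoomServlet. A sketch of applying it to a further, purely hypothetical servlet in the same module:

class GroupPingServlet(RestServlet):
    # Hypothetical servlet, for illustration only.
    PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/ping$")

    def __init__(self, hs):
        super().__init__()
        self.auth = hs.get_auth()

    @_validate_group_id
    async def on_GET(self, request, group_id):
        await self.auth.get_user_by_req(request, allow_guest=True)
        # A malformed group_id has already been rejected with a 400 by the
        # decorator, so it is known to be well-formed here.
        return 200, {"group_id": group_id}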
index 55c4606569..b91996c738 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2019 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -67,6 +68,7 @@ class KeyUploadServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() + self.device_handler = hs.get_device_handler() @trace(opname="upload_keys") async def on_POST(self, request, device_id): @@ -75,23 +77,28 @@ class KeyUploadServlet(RestServlet): body = parse_json_object_from_request(request) if device_id is not None: - # passing the device_id here is deprecated; however, we allow it - # for now for compatibility with older clients. + # Providing the device_id should only be done for setting keys + # for dehydrated devices; however, we allow it for any device for + # compatibility with older clients. if requester.device_id is not None and device_id != requester.device_id: - set_tag("error", True) - log_kv( - { - "message": "Client uploading keys for a different device", - "logged_in_id": requester.device_id, - "key_being_uploaded": device_id, - } - ) - logger.warning( - "Client uploading keys for a different device " - "(logged in as %s, uploading for %s)", - requester.device_id, - device_id, + dehydrated_device = await self.device_handler.get_dehydrated_device( + user_id ) + if dehydrated_device is not None and device_id != dehydrated_device[0]: + set_tag("error", True) + log_kv( + { + "message": "Client uploading keys for a different device", + "logged_in_id": requester.device_id, + "key_being_uploaded": device_id, + } + ) + logger.warning( + "Client uploading keys for a different device " + "(logged in as %s, uploading for %s)", + requester.device_id, + device_id, + ) else: device_id = requester.device_id diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index ffa2dfce42..6b5a1b7109 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -45,6 +45,7 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) +from synapse.metrics import threepid_send_requests from synapse.push.mailer import Mailer from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.ratelimitutils import FederationRateLimiter @@ -78,7 +79,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): """ super().__init__() self.hs = hs - self.identity_handler = hs.get_handlers().identity_handler + self.identity_handler = hs.get_identity_handler() self.config = hs.config if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: @@ -134,7 +135,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it # look like we did something. - await self.hs.clock.sleep(random.randint(1, 10) / 10) + await self.hs.get_clock().sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) @@ -163,6 +164,10 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): # Wrap the session id in a JSON object ret = {"sid": sid} + threepid_send_requests.labels(type="email", reason="register").observe( + send_attempt + ) + return 200, ret @@ -176,7 +181,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): """ super().__init__() self.hs = hs - self.identity_handler = hs.get_handlers().identity_handler + self.identity_handler = hs.get_identity_handler() async def on_POST(self, request): body = parse_json_object_from_request(request) @@ -209,7 +214,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it # look like we did something. - await self.hs.clock.sleep(random.randint(1, 10) / 10) + await self.hs.get_clock().sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError( @@ -234,6 +239,10 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): next_link, ) + threepid_send_requests.labels(type="msisdn", reason="register").observe( + send_attempt + ) + return 200, ret @@ -370,7 +379,7 @@ class RegisterRestServlet(RestServlet): self.store = hs.get_datastore() self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() - self.identity_handler = hs.get_handlers().identity_handler + self.identity_handler = hs.get_identity_handler() self.room_member_handler = hs.get_room_member_handler() self.macaroon_gen = hs.get_macaroon_generator() self.ratelimiter = hs.get_registration_ratelimiter() @@ -442,7 +451,7 @@ class RegisterRestServlet(RestServlet): # == Normal User Registration == (everyone else) if not self._registration_enabled: - raise SynapseError(403, "Registration has been disabled") + raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN) # For regular registration, convert the provided username to lowercase # before attempting to register it. 
This should mean that people who try @@ -642,16 +651,17 @@ class RegisterRestServlet(RestServlet): return 200, return_dict - def on_OPTIONS(self, _): - return 200, {} - async def _do_appservice_registration(self, username, as_token, body): user_id = await self.registration_handler.appservice_register( username, as_token ) - return await self._create_registration_details(user_id, body) + return await self._create_registration_details( + user_id, body, is_appservice_ghost=True, + ) - async def _create_registration_details(self, user_id, params): + async def _create_registration_details( + self, user_id, params, is_appservice_ghost=False + ): """Complete registration of newly-registered user Allocates device_id if one was not given; also creates access_token. @@ -668,7 +678,11 @@ class RegisterRestServlet(RestServlet): device_id = params.get("device_id") initial_display_name = params.get("initial_device_display_name") device_id, access_token = await self.registration_handler.register_device( - user_id, device_id, initial_display_name, is_guest=False + user_id, + device_id, + initial_display_name, + is_guest=False, + is_appservice_ghost=is_appservice_ghost, ) result.update({"access_token": access_token, "device_id": device_id}) diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index bc4f43639a..a3dee14ed4 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -17,7 +17,7 @@ import logging from typing import Tuple from synapse.http import servlet -from synapse.http.servlet import parse_json_object_from_request +from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request from synapse.logging.opentracing import set_tag, trace from synapse.rest.client.transactions import HttpTransactionCache @@ -54,6 +54,7 @@ class SendToDeviceRestServlet(servlet.RestServlet): requester = await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) + assert_params_in_dict(content, ("messages",)) sender_user_id = requester.user.to_string() diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 6779df952f..8e52e4cca4 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -171,6 +171,7 @@ class SyncRestServlet(RestServlet): ) with context: sync_result = await self.sync_handler.wait_for_sync_for_user( + requester, sync_config, since_token=since_token, timeout=timeout, @@ -236,6 +237,7 @@ class SyncRestServlet(RestServlet): "leave": sync_result.groups.leave, }, "device_one_time_keys_count": sync_result.device_one_time_keys_count, + "org.matrix.msc2732.device_unused_fallback_key_types": sync_result.device_unused_fallback_key_types, "next_batch": await sync_result.next_batch.to_string(self.store), } diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
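The sync hunk above surfaces the unused fallback key types under an unstable, MSC2732-prefixed field alongside the existing one-time-key counts. Roughly what the relevant fragment of a /sync response looks like after this change; the values below are invented for illustration.

sync_fragment = {
    "device_one_time_keys_count": {"signed_curve25519": 50},
    "org.matrix.msc2732.device_unused_fallback_key_types": ["signed_curve25519"],
    "next_batch": "s72595_4483_1934",
}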
index c16280f668..d8e8e48c1c 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -66,7 +66,7 @@ class LocalKey(Resource): def __init__(self, hs): self.config = hs.config - self.clock = hs.clock + self.clock = hs.get_clock() self.update_response_body(self.clock.time_msec()) Resource.__init__(self) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index f843f02454..c57ac22e58 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Dict, Set +from typing import Dict from signedjson.sign import sign_json @@ -142,12 +142,13 @@ class RemoteKey(DirectServeJsonResource): time_now_ms = self.clock.time_msec() - cache_misses = {} # type: Dict[str, Set[str]] + # Note that the value is unused. + cache_misses = {} # type: Dict[str, Dict[str, int]] for (server_name, key_id, from_server), results in cached.items(): results = [(result["ts_added_ms"], result) for result in results] if not results and key_id is not None: - cache_misses.setdefault(server_name, set()).add(key_id) + cache_misses.setdefault(server_name, {})[key_id] = 0 continue if key_id is not None: @@ -201,7 +202,7 @@ class RemoteKey(DirectServeJsonResource): ) if miss: - cache_misses.setdefault(server_name, set()).add(key_id) + cache_misses.setdefault(server_name, {})[key_id] = 0 # Cast to bytes since postgresql returns a memoryview. json_results.add(bytes(most_recent_result["key_json"])) else: diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 6568e61829..47c2b44bff 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -155,6 +155,11 @@ def add_file_headers(request, media_type, file_size, upload_name): request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400") request.setHeader(b"Content-Length", b"%d" % (file_size,)) + # Tell web crawlers to not index, archive, or follow links in media. This + # should help to prevent things in the media repo from showing up in web + # search results. + request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex") + # separators as defined in RFC2616. SP and HT are handled separately. # see _can_encode_filename_as_token. @@ -213,6 +218,12 @@ async def respond_with_responder( file_size (int|None): Size in bytes of the media. If not known it should be None upload_name (str|None): The name of the requested file, if any. """ + if request._disconnected: + logger.warning( + "Not sending response to request %s, already disconnected.", request + ) + return + if not responder: respond_404(request) return diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index 7447eeaebe..9e079f672f 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -69,6 +69,23 @@ class MediaFilePaths: local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel) + def local_media_thumbnail_dir(self, media_id: str) -> str: + """ + Retrieve the local store path of thumbnails of a given media_id + + Args: + media_id: The media ID to query. + Returns: + Path of local_thumbnails from media_id + """ + return os.path.join( + self.base_path, + "local_thumbnails", + media_id[0:2], + media_id[2:4], + media_id[4:], + ) + def remote_media_filepath_rel(self, server_name, file_id): return os.path.join( "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:] diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
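The new local_media_thumbnail_dir helper above reuses the existing two-character sharding of media IDs. A quick sketch of the directory it resolves to, assuming a media store rooted at /data/media_store:

import os

base_path = "/data/media_store"        # assumed media store location
media_id = "GCmhgzMPRjqgpODLsNQzVuHZ"  # example 24-character media ID
thumbnail_dir = os.path.join(
    base_path, "local_thumbnails", media_id[0:2], media_id[2:4], media_id[4:]
)
print(thumbnail_dir)
# /data/media_store/local_thumbnails/GC/mh/gzMPRjqgpODLsNQzVuHZ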
index e1192b47cd..83beb02b05 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -18,7 +18,7 @@ import errno import logging import os import shutil -from typing import IO, Dict, Optional, Tuple +from typing import IO, Dict, List, Optional, Tuple import twisted.internet.error import twisted.web.http @@ -66,7 +66,7 @@ class MediaRepository: def __init__(self, hs): self.hs = hs self.auth = hs.get_auth() - self.client = hs.get_http_client() + self.client = hs.get_federation_http_client() self.clock = hs.get_clock() self.server_name = hs.hostname self.store = hs.get_datastore() @@ -305,15 +305,12 @@ class MediaRepository: # file_id is the ID we use to track the file locally. If we've already # seen the file then reuse the existing ID, otherwise genereate a new # one. - if media_info: - file_id = media_info["filesystem_id"] - else: - file_id = random_string(24) - - file_info = FileInfo(server_name, file_id) # If we have an entry in the DB, try and look for it if media_info: + file_id = media_info["filesystem_id"] + file_info = FileInfo(server_name, file_id) + if media_info["quarantined_by"]: logger.info("Media is quarantined") raise NotFoundError() @@ -324,14 +321,34 @@ class MediaRepository: # Failed to find the file anywhere, lets download it. - media_info = await self._download_remote_file(server_name, media_id, file_id) + try: + media_info = await self._download_remote_file(server_name, media_id,) + except SynapseError: + raise + except Exception as e: + # An exception may be because we downloaded media in another + # process, so let's check if we magically have the media. + media_info = await self.store.get_cached_remote_media(server_name, media_id) + if not media_info: + raise e + + file_id = media_info["filesystem_id"] + file_info = FileInfo(server_name, file_id) + + # We generate thumbnails even if another process downloaded the media + # as a) it's conceivable that the other download request dies before it + # generates thumbnails, but mainly b) we want to be sure the thumbnails + # have finished being generated before responding to the client, + # otherwise they'll request thumbnails and get a 404 if they're not + # ready yet. + await self._generate_thumbnails( + server_name, media_id, file_id, media_info["media_type"] + ) responder = await self.media_storage.fetch_media(file_info) return responder, media_info - async def _download_remote_file( - self, server_name: str, media_id: str, file_id: str - ) -> dict: + async def _download_remote_file(self, server_name: str, media_id: str,) -> dict: """Attempt to download the remote file from the given server name, using the given file_id as the local id. @@ -346,6 +363,8 @@ class MediaRepository: The media info of the file. """ + file_id = random_string(24) + file_info = FileInfo(server_name=server_name, file_id=file_id) with self.media_storage.store_into_file(file_info) as (f, fname, finish): @@ -401,22 +420,32 @@ class MediaRepository: await finish() - media_type = headers[b"Content-Type"][0].decode("ascii") - upload_name = get_filename_from_headers(headers) - time_now_ms = self.clock.time_msec() + media_type = headers[b"Content-Type"][0].decode("ascii") + upload_name = get_filename_from_headers(headers) + time_now_ms = self.clock.time_msec() + + # Multiple remote media download requests can race (when using + # multiple media repos), so this may throw a violation constraint + # exception. If it does we'll delete the newly downloaded file from + # disk (as we're in the ctx manager). + # + # However: we've already called `finish()` so we may have also + # written to the storage providers. 
This is preferable to the + # alternative where we call `finish()` *after* this, where we could + # end up having an entry in the DB but fail to write the files to + # the storage providers. + await self.store.store_cached_remote_media( + origin=server_name, + media_id=media_id, + media_type=media_type, + time_now_ms=self.clock.time_msec(), + upload_name=upload_name, + media_length=length, + filesystem_id=file_id, + ) logger.info("Stored remote media in file %r", fname) - await self.store.store_cached_remote_media( - origin=server_name, - media_id=media_id, - media_type=media_type, - time_now_ms=self.clock.time_msec(), - upload_name=upload_name, - media_length=length, - filesystem_id=file_id, - ) - media_info = { "media_type": media_type, "media_length": length, @@ -425,8 +454,6 @@ class MediaRepository: "filesystem_id": file_id, } - await self._generate_thumbnails(server_name, media_id, file_id, media_type) - return media_info def _get_thumbnail_requirements(self, media_type): @@ -692,42 +719,60 @@ class MediaRepository: if not t_byte_source: continue - try: - file_info = FileInfo( - server_name=server_name, - file_id=file_id, - thumbnail=True, - thumbnail_width=t_width, - thumbnail_height=t_height, - thumbnail_method=t_method, - thumbnail_type=t_type, - url_cache=url_cache, - ) - - output_path = await self.media_storage.store_file( - t_byte_source, file_info - ) - finally: - t_byte_source.close() - - t_len = os.path.getsize(output_path) + file_info = FileInfo( + server_name=server_name, + file_id=file_id, + thumbnail=True, + thumbnail_width=t_width, + thumbnail_height=t_height, + thumbnail_method=t_method, + thumbnail_type=t_type, + url_cache=url_cache, + ) - # Write to database - if server_name: - await self.store.store_remote_media_thumbnail( - server_name, - media_id, - file_id, - t_width, - t_height, - t_type, - t_method, - t_len, - ) - else: - await self.store.store_local_thumbnail( - media_id, t_width, t_height, t_type, t_method, t_len - ) + with self.media_storage.store_into_file(file_info) as (f, fname, finish): + try: + await self.media_storage.write_to_file(t_byte_source, f) + await finish() + finally: + t_byte_source.close() + + t_len = os.path.getsize(fname) + + # Write to database + if server_name: + # Multiple remote media download requests can race (when + # using multiple media repos), so this may throw a violation + # constraint exception. If it does we'll delete the newly + # generated thumbnail from disk (as we're in the ctx + # manager). + # + # However: we've already called `finish()` so we may have + # also written to the storage providers. This is preferable + # to the alternative where we call `finish()` *after* this, + # where we could end up having an entry in the DB but fail + # to write the files to the storage providers. 
+ try: + await self.store.store_remote_media_thumbnail( + server_name, + media_id, + file_id, + t_width, + t_height, + t_type, + t_method, + t_len, + ) + except Exception as e: + thumbnail_exists = await self.store.get_remote_media_thumbnail( + server_name, media_id, t_width, t_height, t_type, + ) + if not thumbnail_exists: + raise e + else: + await self.store.store_local_thumbnail( + media_id, t_width, t_height, t_type, t_method, t_len + ) return {"width": m_width, "height": m_height} @@ -767,6 +812,76 @@ class MediaRepository: return {"deleted": deleted} + async def delete_local_media(self, media_id: str) -> Tuple[List[str], int]: + """ + Delete the given local or remote media ID from this server + + Args: + media_id: The media ID to delete. + Returns: + A tuple of (list of deleted media IDs, total deleted media IDs). + """ + return await self._remove_local_media_from_disk([media_id]) + + async def delete_old_local_media( + self, before_ts: int, size_gt: int = 0, keep_profiles: bool = True, + ) -> Tuple[List[str], int]: + """ + Delete local or remote media from this server by size and timestamp. Removes + media files, any thumbnails and cached URLs. + + Args: + before_ts: Unix timestamp in ms. + Files that were last used before this timestamp will be deleted + size_gt: Size of the media in bytes. Files that are larger will be deleted + keep_profiles: Switch to delete also files that are still used in image data + (e.g user profile, room avatar) + If false these files will be deleted + Returns: + A tuple of (list of deleted media IDs, total deleted media IDs). + """ + old_media = await self.store.get_local_media_before( + before_ts, size_gt, keep_profiles, + ) + return await self._remove_local_media_from_disk(old_media) + + async def _remove_local_media_from_disk( + self, media_ids: List[str] + ) -> Tuple[List[str], int]: + """ + Delete local or remote media from this server. Removes media files, + any thumbnails and cached URLs. + + Args: + media_ids: List of media_id to delete + Returns: + A tuple of (list of deleted media IDs, total deleted media IDs). + """ + removed_media = [] + for media_id in media_ids: + logger.info("Deleting media with ID '%s'", media_id) + full_path = self.filepaths.local_media_filepath(media_id) + try: + os.remove(full_path) + except OSError as e: + logger.warning("Failed to remove file: %r: %s", full_path, e) + if e.errno == errno.ENOENT: + pass + else: + continue + + thumbnail_dir = self.filepaths.local_media_thumbnail_dir(media_id) + shutil.rmtree(thumbnail_dir, ignore_errors=True) + + await self.store.delete_remote_media(self.server_name, media_id) + + await self.store.delete_url_cache((media_id,)) + await self.store.delete_url_cache_media((media_id,)) + + removed_media.append(media_id) + + return removed_media, len(removed_media) + class MediaRepositoryResource(Resource): """File uploading and downloading. diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
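The repository now exposes delete_local_media and delete_old_local_media (above), which back the new admin endpoints for purging local media. A sketch of driving the bulk helper directly, with a made-up 30-day retention window; media_repo is assumed to be an existing MediaRepository instance and the wrapper function is hypothetical.

import time

async def purge_stale_local_media(media_repo):
    thirty_days_ms = 30 * 24 * 60 * 60 * 1000
    before_ts = int(time.time() * 1000) - thirty_days_ms
    deleted_ids, total = await media_repo.delete_old_local_media(
        before_ts=before_ts,
        size_gt=0,           # no minimum file size
        keep_profiles=True,  # keep avatars and other profile images
    )
    return {"deleted_media": deleted_ids, "total": total}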
index a9586fb0b7..268e0c8f50 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -52,6 +52,7 @@ class MediaStorage: storage_providers: Sequence["StorageProviderWrapper"], ): self.hs = hs + self.reactor = hs.get_reactor() self.local_media_directory = local_media_directory self.filepaths = filepaths self.storage_providers = storage_providers @@ -70,13 +71,16 @@ class MediaStorage: with self.store_into_file(file_info) as (f, fname, finish_cb): # Write to the main repository - await defer_to_thread( - self.hs.get_reactor(), _write_file_synchronously, source, f - ) + await self.write_to_file(source, f) await finish_cb() return fname + async def write_to_file(self, source: IO, output: IO): + """Asynchronously write the `source` to `output`. + """ + await defer_to_thread(self.reactor, _write_file_synchronously, source, output) + @contextlib.contextmanager def store_into_file(self, file_info: FileInfo): """Context manager used to get a file like object to write into, as @@ -112,14 +116,20 @@ class MediaStorage: finished_called = [False] - async def finish(): - for provider in self.storage_providers: - await provider.store_file(path, file_info) - - finished_called[0] = True - try: with open(fname, "wb") as f: + + async def finish(): + # Ensure that all writes have been flushed and close the + # file. + f.flush() + f.close() + + for provider in self.storage_providers: + await provider.store_file(path, file_info) + + finished_called[0] = True + yield f, fname, finish except Exception: try: @@ -210,7 +220,7 @@ class MediaStorage: if res: with res: consumer = BackgroundFileConsumer( - open(local_path, "wb"), self.hs.get_reactor() + open(local_path, "wb"), self.reactor ) await res.write_to_consumer(consumer) await consumer.wait() diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
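With the media_storage changes above, finish() now flushes and closes the file before mirroring it to the storage providers, and the new write_to_file helper pushes the blocking copy onto a thread. The resulting usage pattern, the same one the thumbnailing code earlier in this diff adopts, looks roughly like this; media_storage, file_info and source are assumed to exist in the caller.

async def store(media_storage, file_info, source):
    with media_storage.store_into_file(file_info) as (f, fname, finish):
        await media_storage.write_to_file(source, f)
        await finish()  # flush/close f, then copy it to the storage providers
    return fname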
index dce6c4d168..1082389d9b 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -676,7 +676,11 @@ class PreviewUrlResource(DirectServeJsonResource): logger.debug("No media removed from url cache") -def decode_and_calc_og(body, media_uri, request_encoding=None): +def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]: + # If there's no body, nothing useful is going to be found. + if not body: + return {} + from lxml import etree try: diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 18c9ed48d6..67f67efde7 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect import logging import os import shutil @@ -21,6 +20,7 @@ from typing import Optional from synapse.config._base import Config from synapse.logging.context import defer_to_thread, run_in_background +from synapse.util.async_helpers import maybe_awaitable from ._base import FileInfo, Responder from .media_storage import FileResponder @@ -91,16 +91,14 @@ class StorageProviderWrapper(StorageProvider): if self.store_synchronous: # store_file is supposed to return an Awaitable, but guard # against improper implementations. - result = self.backend.store_file(path, file_info) - if inspect.isawaitable(result): - return await result + return await maybe_awaitable(self.backend.store_file(path, file_info)) else: # TODO: Handle errors. async def store(): try: - result = self.backend.store_file(path, file_info) - if inspect.isawaitable(result): - return await result + return await maybe_awaitable( + self.backend.store_file(path, file_info) + ) except Exception: logger.exception("Error storing file") @@ -110,9 +108,7 @@ class StorageProviderWrapper(StorageProvider): async def fetch(self, path, file_info): # store_file is supposed to return an Awaitable, but guard # against improper implementations. - result = self.backend.fetch(path, file_info) - if inspect.isawaitable(result): - return await result + return await maybe_awaitable(self.backend.fetch(path, file_info)) class FileStorageProviderBackend(StorageProvider): diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
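The storage_provider change above replaces the manual inspect.isawaitable checks with maybe_awaitable, so a provider backend may implement store_file/fetch either synchronously or as a coroutine. A toy illustration, assuming maybe_awaitable simply turns a non-awaitable return value into something that can be awaited:

import asyncio

from synapse.util.async_helpers import maybe_awaitable

class SyncBackend:
    def store_file(self, path, file_info):
        return None  # plain, synchronous implementation

class AsyncBackend:
    async def store_file(self, path, file_info):
        return None  # coroutine implementation

async def demo():
    for backend in (SyncBackend(), AsyncBackend()):
        # A single call site now handles both styles.
        await maybe_awaitable(backend.store_file("remote_content/x", None))

asyncio.run(demo())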
index d76f7389e1..42febc9afc 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -44,7 +44,7 @@ class UploadResource(DirectServeJsonResource): requester = await self.auth.get_user_by_req(request) # TODO: The checks here are a bit late. The content will have # already been uploaded to a tmp file at this point - content_length = request.getHeader(b"Content-Length").decode("ascii") + content_length = request.getHeader("Content-Length") if content_length is None: raise SynapseError(msg="Request must specify a Content-Length", code=400) if int(content_length) > self.max_upload_size: diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py new file mode 100644
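The one-line upload_resource change above matters because Twisted's getHeader returns None when a header is absent: the old bytes-keyed lookup followed by .decode("ascii") raised AttributeError before the intended 400 response could be produced. A toy reproduction of that failure mode, with None standing in for a request that has no Content-Length header:

content_length = None  # old code: request.getHeader(b"Content-Length").decode("ascii")
try:
    content_length.decode("ascii")
except AttributeError:
    pass  # the new code checks for None first and returns a clean 400 instead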
index 0000000000..d3b6803e65
--- /dev/null
+++ b/synapse/rest/synapse/client/pick_username.py
@@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +import pkg_resources + +from twisted.web.http import Request +from twisted.web.resource import Resource +from twisted.web.static import File + +from synapse.api.errors import SynapseError +from synapse.handlers.sso import USERNAME_MAPPING_SESSION_COOKIE_NAME +from synapse.http.server import DirectServeHtmlResource, DirectServeJsonResource +from synapse.http.servlet import parse_string +from synapse.http.site import SynapseRequest + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +def pick_username_resource(hs: "HomeServer") -> Resource: + """Factory method to generate the username picker resource. + + This resource gets mounted under /_synapse/client/pick_username. The top-level + resource is just a File resource which serves up the static files in the resources + "res" directory, but it has a couple of children: + + * "submit", which does the mechanics of registering the new user, and redirects the + browser back to the client URL + + * "check": checks if a userid is free. + """ + + # XXX should we make this path customisable so that admins can restyle it? + base_path = pkg_resources.resource_filename("synapse", "res/username_picker") + + res = File(base_path) + res.putChild(b"submit", SubmitResource(hs)) + res.putChild(b"check", AvailabilityCheckResource(hs)) + + return res + + +class AvailabilityCheckResource(DirectServeJsonResource): + def __init__(self, hs: "HomeServer"): + super().__init__() + self._sso_handler = hs.get_sso_handler() + + async def _async_render_GET(self, request: Request): + localpart = parse_string(request, "username", required=True) + + session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME) + if not session_id: + raise SynapseError(code=400, msg="missing session_id") + + is_available = await self._sso_handler.check_username_availability( + localpart, session_id.decode("ascii", errors="replace") + ) + return 200, {"available": is_available} + + +class SubmitResource(DirectServeHtmlResource): + def __init__(self, hs: "HomeServer"): + super().__init__() + self._sso_handler = hs.get_sso_handler() + + async def _async_render_POST(self, request: SynapseRequest): + localpart = parse_string(request, "username", required=True) + + session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME) + if not session_id: + raise SynapseError(code=400, msg="missing session_id") + + await self._sso_handler.handle_submit_username_request( + request, localpart, session_id.decode("ascii", errors="replace") + ) diff --git a/synapse/server.py b/synapse/server.py
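pick_username.py above mounts a small static username picker with "check" and "submit" children under /_synapse/client/pick_username. A rough sketch of calling the availability check, assuming that mount point; the homeserver URL, cookie name and cookie value are placeholders for whatever the SSO username-mapping session actually issues.

import requests

resp = requests.get(
    "https://matrix.example.org/_synapse/client/pick_username/check",
    params={"username": "alice"},
    cookies={"username_mapping_session": "<session id from the SSO redirect>"},
)
print(resp.json())  # e.g. {'available': True}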
index 5e3752c333..a198b0eb46 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -27,7 +27,8 @@ import logging import os from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast -import twisted +import twisted.internet.base +import twisted.internet.tcp from twisted.mail.smtp import sendmail from twisted.web.iweb import IPolicyForHTTPS @@ -54,25 +55,28 @@ from synapse.federation.sender import FederationSender from synapse.federation.transport.client import TransportLayerClient from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler -from synapse.handlers import Handlers from synapse.handlers.account_validity import AccountValidityHandler from synapse.handlers.acme import AcmeHandler +from synapse.handlers.admin import AdminHandler from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.auth import AuthHandler, MacaroonGenerator from synapse.handlers.cas_handler import CasHandler from synapse.handlers.deactivate_account import DeactivateAccountHandler from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler from synapse.handlers.devicemessage import DeviceMessageHandler +from synapse.handlers.directory import DirectoryHandler from synapse.handlers.e2e_keys import E2eKeysHandler from synapse.handlers.e2e_room_keys import E2eRoomKeysHandler from synapse.handlers.events import EventHandler, EventStreamHandler +from synapse.handlers.federation import FederationHandler from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerHandler +from synapse.handlers.identity import IdentityHandler from synapse.handlers.initial_sync import InitialSyncHandler from synapse.handlers.message import EventCreationHandler, MessageHandler from synapse.handlers.pagination import PaginationHandler from synapse.handlers.password_policy import PasswordPolicyHandler from synapse.handlers.presence import PresenceHandler -from synapse.handlers.profile import BaseProfileHandler, MasterProfileHandler +from synapse.handlers.profile import ProfileHandler from synapse.handlers.read_marker import ReadMarkerHandler from synapse.handlers.receipts import ReceiptsHandler from synapse.handlers.register import RegistrationHandler @@ -84,13 +88,16 @@ from synapse.handlers.room import ( from synapse.handlers.room_list import RoomListHandler from synapse.handlers.room_member import RoomMemberMasterHandler from synapse.handlers.room_member_worker import RoomMemberWorkerHandler +from synapse.handlers.search import SearchHandler from synapse.handlers.set_password import SetPasswordHandler +from synapse.handlers.sso import SsoHandler from synapse.handlers.stats import StatsHandler from synapse.handlers.sync import SyncHandler from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler from synapse.handlers.user_directory import UserDirectoryHandler from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient +from synapse.module_api import ModuleApi from synapse.notifier import Notifier from synapse.push.action_generator import ActionGenerator from synapse.push.pusherpool import PusherPool @@ -140,7 +147,8 @@ def cache_in_self(builder: T) -> T: "@cache_in_self can only be used on functions starting with `get_`" ) - depname = builder.__name__[len("get_") :] + # get_attr -> _attr + depname = builder.__name__[len("get") :] building = [False] @@ -185,14 +193,28 @@ class 
HomeServer(metaclass=abc.ABCMeta): we are listening on to provide HTTP services. """ - REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] + REQUIRED_ON_BACKGROUND_TASK_STARTUP = [ + "account_validity", + "auth", + "deactivate_account", + "message", + "pagination", + "profile", + "stats", + ] # This is overridden in derived application classes # (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be # instantiated during setup() for future return by get_datastore() DATASTORE_CLASS = abc.abstractproperty() - def __init__(self, hostname: str, config: HomeServerConfig, reactor=None, **kwargs): + def __init__( + self, + hostname: str, + config: HomeServerConfig, + reactor=None, + version_string="Synapse", + ): """ Args: hostname : The hostname for the server. @@ -214,21 +236,10 @@ class HomeServer(metaclass=abc.ABCMeta): self._instance_id = random_string(5) self._instance_name = config.worker_name or "master" - self.clock = Clock(reactor) - self.distributor = Distributor() - - self.registration_ratelimiter = Ratelimiter( - clock=self.clock, - rate_hz=config.rc_registration.per_second, - burst_count=config.rc_registration.burst_count, - ) + self.version_string = version_string self.datastores = None # type: Optional[Databases] - # Other kwargs are explicit dependencies - for depname in kwargs: - setattr(self, depname, kwargs[depname]) - def get_instance_id(self) -> str: """A unique ID for this synapse process instance. @@ -251,14 +262,20 @@ class HomeServer(metaclass=abc.ABCMeta): self.datastores = Databases(self.DATASTORE_CLASS, self) logger.info("Finished setting up.") - def setup_master(self) -> None: + # Register background tasks required by this server. This must be done + # somewhat manually due to the background tasks not being registered + # unless handlers are instantiated. + if self.config.run_background_tasks: + self.setup_background_tasks() + + def setup_background_tasks(self) -> None: """ Some handlers have side effects on instantiation (like registering background updates). This function causes them to be fetched, and therefore instantiated, to run those side effects. 
""" - for i in self.REQUIRED_ON_MASTER_STARTUP: - getattr(self, "get_" + i)() + for i in self.REQUIRED_ON_BACKGROUND_TASK_STARTUP: + getattr(self, "get_" + i + "_handler")() def get_reactor(self) -> twisted.internet.base.ReactorBase: """ @@ -276,8 +293,9 @@ class HomeServer(metaclass=abc.ABCMeta): def is_mine_id(self, string: str) -> bool: return string.split(":", 1)[1] == self.hostname + @cache_in_self def get_clock(self) -> Clock: - return self.clock + return Clock(self._reactor) def get_datastore(self) -> DataStore: if not self.datastores: @@ -294,11 +312,17 @@ class HomeServer(metaclass=abc.ABCMeta): def get_config(self) -> HomeServerConfig: return self.config + @cache_in_self def get_distributor(self) -> Distributor: - return self.distributor + return Distributor() + @cache_in_self def get_registration_ratelimiter(self) -> Ratelimiter: - return self.registration_ratelimiter + return Ratelimiter( + clock=self.get_clock(), + rate_hz=self.config.rc_registration.per_second, + burst_count=self.config.rc_registration.burst_count, + ) @cache_in_self def get_federation_client(self) -> FederationClient: @@ -309,10 +333,6 @@ class HomeServer(metaclass=abc.ABCMeta): return FederationServer(self) @cache_in_self - def get_handlers(self) -> Handlers: - return Handlers(self) - - @cache_in_self def get_notifier(self) -> Notifier: return Notifier(self) @@ -330,10 +350,16 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_simple_http_client(self) -> SimpleHttpClient: + """ + An HTTP client with no special configuration. + """ return SimpleHttpClient(self) @cache_in_self def get_proxied_http_client(self) -> SimpleHttpClient: + """ + An HTTP client that uses configured HTTP(S) proxies. + """ return SimpleHttpClient( self, http_proxy=os.getenvb(b"http_proxy"), @@ -341,6 +367,30 @@ class HomeServer(metaclass=abc.ABCMeta): ) @cache_in_self + def get_proxied_blacklisted_http_client(self) -> SimpleHttpClient: + """ + An HTTP client that uses configured HTTP(S) proxies and blacklists IPs + based on the IP range blacklist/whitelist. + """ + return SimpleHttpClient( + self, + ip_whitelist=self.config.ip_range_whitelist, + ip_blacklist=self.config.ip_range_blacklist, + http_proxy=os.getenvb(b"http_proxy"), + https_proxy=os.getenvb(b"HTTPS_PROXY"), + ) + + @cache_in_self + def get_federation_http_client(self) -> MatrixFederationHttpClient: + """ + An HTTP client for federation. 
+ """ + tls_client_options_factory = context_factory.FederationPolicyForHTTPS( + self.config + ) + return MatrixFederationHttpClient(self, tls_client_options_factory) + + @cache_in_self def get_room_creation_handler(self) -> RoomCreationHandler: return RoomCreationHandler(self) @@ -372,6 +422,10 @@ class HomeServer(metaclass=abc.ABCMeta): return FollowerTypingHandler(self) @cache_in_self + def get_sso_handler(self) -> SsoHandler: + return SsoHandler(self) + + @cache_in_self def get_sync_handler(self) -> SyncHandler: return SyncHandler(self) @@ -399,6 +453,10 @@ class HomeServer(metaclass=abc.ABCMeta): return DeviceMessageHandler(self) @cache_in_self + def get_directory_handler(self) -> DirectoryHandler: + return DirectoryHandler(self) + + @cache_in_self def get_e2e_keys_handler(self) -> E2eKeysHandler: return E2eKeysHandler(self) @@ -411,6 +469,10 @@ class HomeServer(metaclass=abc.ABCMeta): return AcmeHandler(self) @cache_in_self + def get_admin_handler(self) -> AdminHandler: + return AdminHandler(self) + + @cache_in_self def get_application_service_api(self) -> ApplicationServiceApi: return ApplicationServiceApi(self) @@ -431,15 +493,20 @@ class HomeServer(metaclass=abc.ABCMeta): return EventStreamHandler(self) @cache_in_self + def get_federation_handler(self) -> FederationHandler: + return FederationHandler(self) + + @cache_in_self + def get_identity_handler(self) -> IdentityHandler: + return IdentityHandler(self) + + @cache_in_self def get_initial_sync_handler(self) -> InitialSyncHandler: return InitialSyncHandler(self) @cache_in_self def get_profile_handler(self): - if self.config.worker_app: - return BaseProfileHandler(self) - else: - return MasterProfileHandler(self) + return ProfileHandler(self) @cache_in_self def get_event_creation_handler(self) -> EventCreationHandler: @@ -450,6 +517,10 @@ class HomeServer(metaclass=abc.ABCMeta): return DeactivateAccountHandler(self) @cache_in_self + def get_search_handler(self) -> SearchHandler: + return SearchHandler(self) + + @cache_in_self def get_set_password_handler(self) -> SetPasswordHandler: return SetPasswordHandler(self) @@ -474,13 +545,6 @@ class HomeServer(metaclass=abc.ABCMeta): return PusherPool(self) @cache_in_self - def get_http_client(self) -> MatrixFederationHttpClient: - tls_client_options_factory = context_factory.FederationPolicyForHTTPS( - self.config - ) - return MatrixFederationHttpClient(self, tls_client_options_factory) - - @cache_in_self def get_media_repository_resource(self) -> MediaRepositoryResource: # build the media repo resource. This indirects through the HomeServer # to ensure that we only have a single instance of @@ -554,7 +618,7 @@ class HomeServer(metaclass=abc.ABCMeta): return StatsHandler(self) @cache_in_self - def get_spam_checker(self): + def get_spam_checker(self) -> SpamChecker: return SpamChecker(self) @cache_in_self @@ -645,7 +709,11 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_federation_ratelimiter(self) -> FederationRateLimiter: - return FederationRateLimiter(self.clock, config=self.config.rc_federation) + return FederationRateLimiter(self.get_clock(), config=self.config.rc_federation) + + @cache_in_self + def get_module_api(self) -> ModuleApi: + return ModuleApi(self, self.get_auth_handler()) async def remove_pusher(self, app_id: str, push_key: str, user_id: str): return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py
index 3673e7f47e..9137c4edb1 100644
--- a/synapse/server_notices/consent_server_notices.py
+++ b/synapse/server_notices/consent_server_notices.py
@@ -104,7 +104,7 @@ class ConsentServerNotices: def copy_with_str_subst(x: Any, substitutions: Any) -> Any: - """Deep-copy a structure, carrying out string substitions on any strings + """Deep-copy a structure, carrying out string substitutions on any strings Args: x (object): structure to be copied diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index 0422d4c7ce..100dbd5e2c 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
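The hunks below attribute server-notice requests to the homeserver itself (`authenticated_entity=self._server_name`) and tighten the check for reusing an existing notices room to rooms with at most two members. A small standalone illustration of that membership guard; the helper name and the example.org IDs are invented for illustration:

def looks_like_server_notices_room(user_ids, server_notices_mxid):
    # A 1:1 room between the system user and the target user has at most
    # two members; anything larger is a room someone merely invited the
    # system user into, so it must not be reused for notices.
    return len(user_ids) <= 2 and server_notices_mxid in user_ids

assert looks_like_server_notices_room(
    ["@notices:example.org", "@alice:example.org"], "@notices:example.org"
)
assert not looks_like_server_notices_room(
    ["@notices:example.org", "@alice:example.org", "@bob:example.org"],
    "@notices:example.org",
)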
@@ -39,6 +39,7 @@ class ServerNoticesManager: self._room_member_handler = hs.get_room_member_handler() self._event_creation_handler = hs.get_event_creation_handler() self._is_mine_id = hs.is_mine_id + self._server_name = hs.hostname self._notifier = hs.get_notifier() self.server_notices_mxid = self._config.server_notices_mxid @@ -72,7 +73,9 @@ class ServerNoticesManager: await self.maybe_invite_user_to_room(user_id, room_id) system_mxid = self._config.server_notices_mxid - requester = create_requester(system_mxid) + requester = create_requester( + system_mxid, authenticated_entity=self._server_name + ) logger.info("Sending server notice to %s", user_id) @@ -119,7 +122,7 @@ class ServerNoticesManager: # manages to invite the system user to a room, that doesn't make it # the server notices room. user_ids = await self._store.get_users_in_room(room.room_id) - if self.server_notices_mxid in user_ids: + if len(user_ids) <= 2 and self.server_notices_mxid in user_ids: # we found a room which our user shares with the system notice # user logger.info( @@ -145,7 +148,9 @@ class ServerNoticesManager: "avatar_url": self._config.server_notices_mxid_avatar_url, } - requester = create_requester(self.server_notices_mxid) + requester = create_requester( + self.server_notices_mxid, authenticated_entity=self._server_name + ) info, _ = await self._room_creation_handler.create_room( requester, config={ @@ -174,7 +179,9 @@ class ServerNoticesManager: user_id: The ID of the user to invite. room_id: The ID of the room to invite the user to. """ - requester = create_requester(self.server_notices_mxid) + requester = create_requester( + self.server_notices_mxid, authenticated_entity=self._server_name + ) # Check whether the user has already joined or been invited to this room. If # that's the case, there is no need to re-invite them. diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py
index 395ac5ab02..3ce25bb012 100644
--- a/synapse/spam_checker_api/__init__.py
+++ b/synapse/spam_checker_api/__init__.py
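The hunks below reduce this module to the RegistrationBehaviour enum; the old SpamCheckerApi proxy is removed now that spam checker modules are handed the richer ModuleApi (see get_module_api in synapse/server.py above). A hedged sketch of how a policy might map a registration attempt to one of the enum values; the decide_registration helper and its rules are purely illustrative, not part of Synapse:

from enum import Enum

class RegistrationBehaviour(Enum):
    ALLOW = "allow"
    SHADOW_BAN = "shadow_ban"
    DENY = "deny"

def decide_registration(username):
    # Purely illustrative policy: deny obvious junk, shadow-ban suspicious
    # names, allow everything else.
    if username and username.startswith("spam"):
        return RegistrationBehaviour.DENY
    if username and username.isdigit():
        return RegistrationBehaviour.SHADOW_BAN
    return RegistrationBehaviour.ALLOW

assert decide_registration("alice") is RegistrationBehaviour.ALLOW
assert decide_registration("spam123") is RegistrationBehaviour.DENY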
@@ -12,19 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging from enum import Enum -from twisted.internet import defer - -from synapse.storage.state import StateFilter - -MYPY = False -if MYPY: - import synapse.server - -logger = logging.getLogger(__name__) - class RegistrationBehaviour(Enum): """ @@ -34,35 +23,3 @@ class RegistrationBehaviour(Enum): ALLOW = "allow" SHADOW_BAN = "shadow_ban" DENY = "deny" - - -class SpamCheckerApi: - """A proxy object that gets passed to spam checkers so they can get - access to rooms and other relevant information. - """ - - def __init__(self, hs: "synapse.server.HomeServer"): - self.hs = hs - - self._store = hs.get_datastore() - - @defer.inlineCallbacks - def get_state_events_in_room(self, room_id: str, types: tuple) -> defer.Deferred: - """Gets state events for the given room. - - Args: - room_id: The room ID to get state events in. - types: The event type and state key (using None - to represent 'any') of the room state to acquire. - - Returns: - twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]: - The filtered state events in the room. - """ - state_ids = yield defer.ensureDeferred( - self._store.get_filtered_current_state_ids( - room_id=room_id, state_filter=StateFilter.from_types(types) - ) - ) - state = yield defer.ensureDeferred(self._store.get_events(state_ids.values())) - return state.values() diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 31082bb16a..84f59c7d85 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
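For orientation on the hunks below: the "auth chain difference" that get_auth_chain_difference computes (per the state resolution v2 algorithm) is the union of the state sets' auth chains minus their intersection, i.e. the auth events that appear in some but not all of the chains. A toy illustration with made-up event IDs:

# Toy auth chains for three state sets (event IDs are invented).
auth_chains = [
    {"$create", "$power", "$join_alice"},
    {"$create", "$power", "$join_bob"},
    {"$create", "$power", "$join_alice", "$join_bob"},
]

union = set().union(*auth_chains)
intersection = set.intersection(*auth_chains)

# The auth chain difference: events in some but not all of the chains.
auth_difference = union - intersection
assert auth_difference == {"$join_alice", "$join_bob"}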
@@ -547,7 +547,7 @@ class StateResolutionHandler: event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be - used as a starting point fof finding the state we need; any missing + used as a starting point for finding the state we need; any missing events will be requested via state_res_store. If None, all events will be fetched via state_res_store. @@ -738,7 +738,7 @@ def _make_state_cache_entry( # failing that, look for the closest match. prev_group = None - delta_ids = None + delta_ids = None # type: Optional[StateMap[str]] for old_group, old_state in state_groups_ids.items(): n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v} @@ -783,7 +783,7 @@ class StateResolutionStore: ) def get_auth_chain_difference( - self, state_sets: List[Set[str]] + self, room_id: str, state_sets: List[Set[str]] ) -> Awaitable[Set[str]]: """Given sets of state events figure out the auth chain difference (as per state res v2 algorithm). @@ -796,4 +796,4 @@ class StateResolutionStore: An awaitable that resolves to a set of event IDs. """ - return self.store.get_auth_chain_difference(state_sets) + return self.store.get_auth_chain_difference(room_id, state_sets) diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index a493279cbd..85edae053d 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -56,7 +56,7 @@ async def resolve_events_with_store( event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be - used as a starting point fof finding the state we need; any missing + used as a starting point for finding the state we need; any missing events will be requested via state_map_factory. If None, all events will be fetched via state_map_factory. diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index edf94e7ad6..e585954bd8 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
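The main change below teaches the v2 resolver to handle events that exist only in `event_map` and have not been persisted yet: their part of the auth chain difference is computed in memory (the union minus the intersection of each state set's unpersisted IDs), while the store is queried with those events replaced by their persisted auth ancestors, and the two results are combined. A simplified sketch of the ancestor walk, using plain dicts and sets with made-up event IDs rather than real events:

# event_id -> auth_event_ids, for events not yet persisted
event_map = {
    "$unpersisted_a": ["$unpersisted_b", "$persisted_1"],
    "$unpersisted_b": ["$persisted_2"],
}

def chain_within_event_map(event_id):
    """Return event_id plus every auth ancestor reachable through event_map."""
    chain, to_search = {event_id}, [event_id]
    while to_search:
        for auth_id in event_map.get(to_search.pop(), []):
            if auth_id not in chain:
                chain.add(auth_id)
                to_search.append(auth_id)
    return chain

chain = chain_within_event_map("$unpersisted_a")
# Persisted ancestors go into the store query; unpersisted ones are handled locally.
persisted_part = {e for e in chain if e not in event_map}
unpersisted_part = {e for e in chain if e in event_map}
assert persisted_part == {"$persisted_1", "$persisted_2"}
assert unpersisted_part == {"$unpersisted_a", "$unpersisted_b"}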
@@ -38,7 +38,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase -from synapse.types import MutableStateMap, StateMap +from synapse.types import Collection, MutableStateMap, StateMap from synapse.util import Clock logger = logging.getLogger(__name__) @@ -69,7 +69,7 @@ async def resolve_events_with_store( event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be - used as a starting point fof finding the state we need; any missing + used as a starting point for finding the state we need; any missing events will be requested via state_res_store. If None, all events will be fetched via state_res_store. @@ -97,7 +97,9 @@ async def resolve_events_with_store( # Also fetch all auth events that appear in only some of the state sets' # auth chains. - auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store) + auth_diff = await _get_auth_chain_difference( + room_id, state_sets, event_map, state_res_store + ) full_conflicted_set = set( itertools.chain( @@ -236,6 +238,7 @@ async def _get_power_level_for_sender( async def _get_auth_chain_difference( + room_id: str, state_sets: Sequence[StateMap[str]], event_map: Dict[str, EventBase], state_res_store: "synapse.state.StateResolutionStore", @@ -252,9 +255,90 @@ async def _get_auth_chain_difference( Set of event IDs """ + # The `StateResolutionStore.get_auth_chain_difference` function assumes that + # all events passed to it (and their auth chains) have been persisted + # previously. This is not the case for any events in the `event_map`, and so + # we need to manually handle those events. + # + # We do this by: + # 1. calculating the auth chain difference for the state sets based on the + # events in `event_map` alone + # 2. replacing any events in the state_sets that are also in `event_map` + # with their auth events (recursively), and then calling + # `store.get_auth_chain_difference` as normal + # 3. adding the results of 1 and 2 together. + + # Map from event ID in `event_map` to their auth event IDs, and their auth + # event IDs if they appear in the `event_map`. This is the intersection of + # the event's auth chain with the events in the `event_map` *plus* their + # auth event IDs. + events_to_auth_chain = {} # type: Dict[str, Set[str]] + for event in event_map.values(): + chain = {event.event_id} + events_to_auth_chain[event.event_id] = chain + + to_search = [event] + while to_search: + for auth_id in to_search.pop().auth_event_ids(): + chain.add(auth_id) + auth_event = event_map.get(auth_id) + if auth_event: + to_search.append(auth_event) + + # We now a) calculate the auth chain difference for the unpersisted events + # and b) work out the state sets to pass to the store. + # + # Note: If the `event_map` is empty (which is the common case), we can do a + # much simpler calculation. + if event_map: + # The list of state sets to pass to the store, where each state set is a set + # of the event ids making up the state. This is similar to `state_sets`, + # except that (a) we only have event ids, not the complete + # ((type, state_key)->event_id) mappings; and (b) we have stripped out + # unpersisted events and replaced them with the persisted events in + # their auth chain. 
+ state_sets_ids = [] # type: List[Set[str]] + + # For each state set, the unpersisted event IDs reachable (by their auth + # chain) from the events in that set. + unpersisted_set_ids = [] # type: List[Set[str]] + + for state_set in state_sets: + set_ids = set() # type: Set[str] + state_sets_ids.append(set_ids) + + unpersisted_ids = set() # type: Set[str] + unpersisted_set_ids.append(unpersisted_ids) + + for event_id in state_set.values(): + event_chain = events_to_auth_chain.get(event_id) + if event_chain is not None: + # We have an event in `event_map`. We add all the auth + # events that it references (that aren't also in `event_map`). + set_ids.update(e for e in event_chain if e not in event_map) + + # We also add the full chain of unpersisted event IDs + # referenced by this state set, so that we can work out the + # auth chain difference of the unpersisted events. + unpersisted_ids.update(e for e in event_chain if e in event_map) + else: + set_ids.add(event_id) + + # The auth chain difference of the unpersisted events of the state sets + # is calculated by taking the difference between the union and + # intersections. + union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:]) + intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:]) + + difference_from_event_map = union - intersection # type: Collection[str] + else: + difference_from_event_map = () + state_sets_ids = [set(state_set.values()) for state_set in state_sets] + difference = await state_res_store.get_auth_chain_difference( - [set(state_set.values()) for state_set in state_sets] + room_id, state_sets_ids ) + difference.update(difference_from_event_map) return difference @@ -574,7 +658,7 @@ async def _get_mainline_depth_for_event( # We do an iterative search, replacing `event with the power level in its # auth events (if any) while tmp_event: - depth = mainline_map.get(event.event_id) + depth = mainline_map.get(tmp_event.event_id) if depth is not None: return depth diff --git a/synapse/static/client/login/js/login.js b/synapse/static/client/login/js/login.js
index 3678670ec7..744800ec77 100644
--- a/synapse/static/client/login/js/login.js
+++ b/synapse/static/client/login/js/login.js
@@ -182,7 +182,7 @@ matrixLogin.passwordLogin = function() { }; /* - * The onLogin function gets called after a succesful login. + * The onLogin function gets called after a successful login. * * It is expected that implementations override this to be notified when the * login is complete. The response to the login call is provided as the single diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index bbff3c8d5b..c0d9d1240f 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -27,6 +27,7 @@ There are also schemas that get applied to every database, regardless of the data stores associated with them (e.g. the schema version tables), which are stored in `synapse.storage.schema`. """ +from typing import TYPE_CHECKING from synapse.storage.databases import Databases from synapse.storage.databases.main import DataStore @@ -34,14 +35,18 @@ from synapse.storage.persist_events import EventsPersistenceStorage from synapse.storage.purge_events import PurgeEventsStorage from synapse.storage.state import StateGroupStorage -__all__ = ["DataStores", "DataStore"] +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + + +__all__ = ["Databases", "DataStore"] class Storage: """The high level interfaces for talking to various storage layers. """ - def __init__(self, hs, stores: Databases): + def __init__(self, hs: "HomeServer", stores: Databases): # We include the main data store here mainly so that we don't have to # rewrite all the existing code to split it into high vs low level # interfaces. diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index ab49d227de..a25c4093bc 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
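Among the hunks below, db_to_json gains a precise signature documenting that database rows may arrive as memoryview, bytes, bytearray, or str (psycopg2 on Python 3 returns memoryview objects), all of which are normalised before JSON decoding. A standalone sketch of that behaviour; db_to_json_sketch is my name, not the real helper:

import json

def db_to_json_sketch(db_content):
    # psycopg2 can hand back a memoryview; cast it to bytes before decoding.
    if isinstance(db_content, memoryview):
        db_content = db_content.tobytes()
    if isinstance(db_content, (bytes, bytearray)):
        db_content = db_content.decode("utf-8")
    return json.loads(db_content)

assert db_to_json_sketch(memoryview(b'{"a": 1}')) == {"a": 1}
assert db_to_json_sketch('{"a": 1}') == {"a": 1}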
@@ -17,14 +17,18 @@ import logging import random from abc import ABCMeta -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Iterable, Optional, Union from synapse.storage.database import LoggingTransaction # noqa: F401 from synapse.storage.database import make_in_list_sql_clause # noqa: F401 from synapse.storage.database import DatabasePool -from synapse.types import Collection, get_domain_from_id +from synapse.storage.types import Connection +from synapse.types import Collection, StreamToken, get_domain_from_id from synapse.util import json_decoder +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -36,24 +40,31 @@ class SQLBaseStore(metaclass=ABCMeta): per data store (and not one per physical database). """ - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): self.hs = hs self._clock = hs.get_clock() self.database_engine = database.engine self.db_pool = database self.rand = random.SystemRandom() - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, + stream_name: str, + instance_name: str, + token: StreamToken, + rows: Iterable[Any], + ) -> None: pass - def _invalidate_state_caches(self, room_id, members_changed): + def _invalidate_state_caches( + self, room_id: str, members_changed: Iterable[str] + ) -> None: """Invalidates caches that are based on the current state, but does not stream invalidations down replication. Args: - room_id (str): Room where state changed - members_changed (iterable[str]): The user_ids of members that have - changed + room_id: Room where state changed + members_changed: The user_ids of members that have changed """ for host in {get_domain_from_id(u) for u in members_changed}: self._attempt_to_invalidate_cache("is_host_joined", (room_id, host)) @@ -64,7 +75,7 @@ class SQLBaseStore(metaclass=ABCMeta): def _attempt_to_invalidate_cache( self, cache_name: str, key: Optional[Collection[Any]] - ): + ) -> None: """Attempts to invalidate the cache of the given name, ignoring if the cache doesn't exist. Mainly used for invalidating caches on workers, where they may not have the cache. @@ -76,22 +87,27 @@ class SQLBaseStore(metaclass=ABCMeta): """ try: - if key is None: - getattr(self, cache_name).invalidate_all() - else: - getattr(self, cache_name).invalidate(tuple(key)) + cache = getattr(self, cache_name) except AttributeError: # We probably haven't pulled in the cache in this worker, # which is fine. - pass + return + + if key is None: + cache.invalidate_all() + else: + cache.invalidate(tuple(key)) -def db_to_json(db_content): +def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any: """ Take some data from a database row and return a JSON-decoded object. Args: - db_content (memoryview|buffer|bytes|bytearray|unicode) + db_content: The JSON-encoded contents from the database. + + Returns: + The object decoded from JSON. """ # psycopg2 on Python 3 returns memoryview objects, which we need to # cast to bytes to decode diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 810721ebe9..29b8ca676a 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
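The typing added below makes the handler contract explicit: a background update handler receives the stored progress dict and a batch size, does up to that much work, persists new progress, and returns how many items it processed. A hedged usage sketch of registering such a handler; MyStore, the "my_cleanup" update name, and _fetch_batch are invented for illustration, while register_background_update_handler, _background_update_progress, and _end_background_update are the methods shown in the hunks below:

class MyStore:  # would normally derive from SQLBaseStore
    def __init__(self, db_pool):
        self.db_pool = db_pool
        self.db_pool.updates.register_background_update_handler(
            "my_cleanup", self._my_cleanup
        )

    async def _my_cleanup(self, progress, batch_size):
        last_id = progress.get("last_id", 0)
        rows = await self._fetch_batch(last_id, batch_size)  # hypothetical query
        if not rows:
            # Nothing left to do: mark the update as finished.
            await self.db_pool.updates._end_background_update("my_cleanup")
            return 0
        # Record how far we got so the next batch resumes from here.
        await self.db_pool.updates._background_update_progress(
            "my_cleanup", {"last_id": rows[-1]}
        )
        return len(rows)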
@@ -12,29 +12,34 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging -from typing import Optional +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.types import Connection +from synapse.types import JsonDict from synapse.util import json_encoder from . import engines +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + from synapse.storage.database import DatabasePool, LoggingTransaction + logger = logging.getLogger(__name__) class BackgroundUpdatePerformance: """Tracks the how long a background update is taking to update its items""" - def __init__(self, name): + def __init__(self, name: str): self.name = name self.total_item_count = 0 - self.total_duration_ms = 0 - self.avg_item_count = 0 - self.avg_duration_ms = 0 + self.total_duration_ms = 0.0 + self.avg_item_count = 0.0 + self.avg_duration_ms = 0.0 - def update(self, item_count, duration_ms): + def update(self, item_count: int, duration_ms: float) -> None: """Update the stats after doing an update""" self.total_item_count += item_count self.total_duration_ms += duration_ms @@ -44,7 +49,7 @@ class BackgroundUpdatePerformance: self.avg_item_count += 0.1 * (item_count - self.avg_item_count) self.avg_duration_ms += 0.1 * (duration_ms - self.avg_duration_ms) - def average_items_per_ms(self): + def average_items_per_ms(self) -> Optional[float]: """An estimate of how long it takes to do a single update. Returns: A duration in ms as a float @@ -58,7 +63,7 @@ class BackgroundUpdatePerformance: # changes in how long the update process takes. return float(self.avg_item_count) / float(self.avg_duration_ms) - def total_items_per_ms(self): + def total_items_per_ms(self) -> Optional[float]: """An estimate of how long it takes to do a single update. Returns: A duration in ms as a float @@ -83,21 +88,25 @@ class BackgroundUpdater: BACKGROUND_UPDATE_INTERVAL_MS = 1000 BACKGROUND_UPDATE_DURATION_MS = 100 - def __init__(self, hs, database): + def __init__(self, hs: "HomeServer", database: "DatabasePool"): self._clock = hs.get_clock() self.db_pool = database # if a background update is currently running, its name. self._current_background_update = None # type: Optional[str] - self._background_update_performance = {} - self._background_update_handlers = {} + self._background_update_performance = ( + {} + ) # type: Dict[str, BackgroundUpdatePerformance] + self._background_update_handlers = ( + {} + ) # type: Dict[str, Callable[[JsonDict, int], Awaitable[int]]] self._all_done = False - def start_doing_background_updates(self): + def start_doing_background_updates(self) -> None: run_as_background_process("background_updates", self.run_background_updates) - async def run_background_updates(self, sleep=True): + async def run_background_updates(self, sleep: bool = True) -> None: logger.info("Starting background schema updates") while True: if sleep: @@ -148,7 +157,7 @@ class BackgroundUpdater: return False - async def has_completed_background_update(self, update_name) -> bool: + async def has_completed_background_update(self, update_name: str) -> bool: """Check if the given background update has finished running. """ if self._all_done: @@ -173,8 +182,7 @@ class BackgroundUpdater: Returns once some amount of work is done. 
Args: - desired_duration_ms(float): How long we want to spend - updating. + desired_duration_ms: How long we want to spend updating. Returns: True if we have finished running all the background updates, otherwise False """ @@ -220,6 +228,7 @@ class BackgroundUpdater: return False async def _do_background_update(self, desired_duration_ms: float) -> int: + assert self._current_background_update is not None update_name = self._current_background_update logger.info("Starting update batch on background update '%s'", update_name) @@ -273,7 +282,11 @@ class BackgroundUpdater: return len(self._background_update_performance) - def register_background_update_handler(self, update_name, update_handler): + def register_background_update_handler( + self, + update_name: str, + update_handler: Callable[[JsonDict, int], Awaitable[int]], + ): """Register a handler for doing a background update. The handler should take two arguments: @@ -287,12 +300,12 @@ class BackgroundUpdater: The handler is responsible for updating the progress of the update. Args: - update_name(str): The name of the update that this code handles. - update_handler(function): The function that does the update. + update_name: The name of the update that this code handles. + update_handler: The function that does the update. """ self._background_update_handlers[update_name] = update_handler - def register_noop_background_update(self, update_name): + def register_noop_background_update(self, update_name: str) -> None: """Register a noop handler for a background update. This is useful when we previously did a background update, but no @@ -302,10 +315,10 @@ class BackgroundUpdater: also be called to clear the update. Args: - update_name (str): Name of update + update_name: Name of update """ - async def noop_update(progress, batch_size): + async def noop_update(progress: JsonDict, batch_size: int) -> int: await self._end_background_update(update_name) return 1 @@ -313,14 +326,14 @@ class BackgroundUpdater: def register_background_index_update( self, - update_name, - index_name, - table, - columns, - where_clause=None, - unique=False, - psql_only=False, - ): + update_name: str, + index_name: str, + table: str, + columns: Iterable[str], + where_clause: Optional[str] = None, + unique: bool = False, + psql_only: bool = False, + ) -> None: """Helper for store classes to do a background index addition To use: @@ -332,19 +345,19 @@ class BackgroundUpdater: 2. 
In the Store constructor, call this method Args: - update_name (str): update_name to register for - index_name (str): name of index to add - table (str): table to add index to - columns (list[str]): columns/expressions to include in index - unique (bool): true to make a UNIQUE index + update_name: update_name to register for + index_name: name of index to add + table: table to add index to + columns: columns/expressions to include in index + unique: true to make a UNIQUE index psql_only: true to only create this index on psql databases (useful for virtual sqlite tables) """ - def create_index_psql(conn): + def create_index_psql(conn: Connection) -> None: conn.rollback() # postgres insists on autocommit for the index - conn.set_session(autocommit=True) + conn.set_session(autocommit=True) # type: ignore try: c = conn.cursor() @@ -371,9 +384,9 @@ class BackgroundUpdater: logger.debug("[SQL] %s", sql) c.execute(sql) finally: - conn.set_session(autocommit=False) + conn.set_session(autocommit=False) # type: ignore - def create_index_sqlite(conn): + def create_index_sqlite(conn: Connection) -> None: # Sqlite doesn't support concurrent creation of indexes. # # We don't use partial indices on SQLite as it wasn't introduced @@ -399,7 +412,7 @@ class BackgroundUpdater: c.execute(sql) if isinstance(self.db_pool.engine, engines.PostgresEngine): - runner = create_index_psql + runner = create_index_psql # type: Optional[Callable[[Connection], None]] elif psql_only: runner = None else: @@ -433,7 +446,9 @@ class BackgroundUpdater: "background_updates", keyvalues={"update_name": update_name} ) - async def _background_update_progress(self, update_name: str, progress: dict): + async def _background_update_progress( + self, update_name: str, progress: dict + ) -> None: """Update the progress of a background update Args: @@ -441,20 +456,22 @@ class BackgroundUpdater: progress: The progress of the update. """ - return await self.db_pool.runInteraction( + await self.db_pool.runInteraction( "background_update_progress", self._background_update_progress_txn, update_name, progress, ) - def _background_update_progress_txn(self, txn, update_name, progress): + def _background_update_progress_txn( + self, txn: "LoggingTransaction", update_name: str, progress: JsonDict + ) -> None: """Update the progress of a background update Args: - txn(cursor): The transaction. - update_name(str): The name of the background update task - progress(dict): The progress of the update. + txn: The transaction. + update_name: The name of the background update task + progress: The progress of the update. """ progress_json = json_encoder.encode(progress) diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 6116191b16..d1b5760c2c 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
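Among other things, the hunks below add an async simple_upsert_many wrapper around simple_upsert_many_txn, autocommitting when native upserts are available. A hedged usage sketch matching the signature shown below; ExampleStore, the device_last_seen table, and its columns are invented for illustration:

class ExampleStore:  # hypothetical store using the new helper
    def __init__(self, db_pool):
        self.db_pool = db_pool

    async def record_last_seen(self):
        # One upsert per (user_id, device_id) key, setting last_seen_ts.
        await self.db_pool.simple_upsert_many(
            table="device_last_seen",  # invented table name
            key_names=("user_id", "device_id"),
            key_values=[("@alice:example.org", "DEV1"), ("@bob:example.org", "DEV2")],
            value_names=("last_seen_ts",),
            value_values=[(1_600_000_000_000,), (1_600_000_123_000,)],
            desc="record_last_seen",
        )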
@@ -32,6 +32,7 @@ from typing import ( overload, ) +import attr from prometheus_client import Histogram from typing_extensions import Literal @@ -87,16 +88,25 @@ def make_pool( """Get the connection pool for the database. """ + # By default enable `cp_reconnect`. We need to fiddle with db_args in case + # someone has explicitly set `cp_reconnect`. + db_args = dict(db_config.config.get("args", {})) + db_args.setdefault("cp_reconnect", True) + return adbapi.ConnectionPool( db_config.config["name"], cp_reactor=reactor, - cp_openfun=engine.on_new_connection, - **db_config.config.get("args", {}) + cp_openfun=lambda conn: engine.on_new_connection( + LoggingDatabaseConnection(conn, engine, "on_new_connection") + ), + **db_args, ) def make_conn( - db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine + db_config: DatabaseConnectionConfig, + engine: BaseDatabaseEngine, + default_txn_name: str, ) -> Connection: """Make a new connection to the database and return it. @@ -109,11 +119,60 @@ def make_conn( for k, v in db_config.config.get("args", {}).items() if not k.startswith("cp_") } - db_conn = engine.module.connect(**db_params) + native_db_conn = engine.module.connect(**db_params) + db_conn = LoggingDatabaseConnection(native_db_conn, engine, default_txn_name) + engine.on_new_connection(db_conn) return db_conn +@attr.s(slots=True) +class LoggingDatabaseConnection: + """A wrapper around a database connection that returns `LoggingTransaction` + as its cursor class. + + This is mainly used on startup to ensure that queries get logged correctly + """ + + conn = attr.ib(type=Connection) + engine = attr.ib(type=BaseDatabaseEngine) + default_txn_name = attr.ib(type=str) + + def cursor( + self, *, txn_name=None, after_callbacks=None, exception_callbacks=None + ) -> "LoggingTransaction": + if not txn_name: + txn_name = self.default_txn_name + + return LoggingTransaction( + self.conn.cursor(), + name=txn_name, + database_engine=self.engine, + after_callbacks=after_callbacks, + exception_callbacks=exception_callbacks, + ) + + def close(self) -> None: + self.conn.close() + + def commit(self) -> None: + self.conn.commit() + + def rollback(self, *args, **kwargs) -> None: + self.conn.rollback(*args, **kwargs) + + def __enter__(self) -> "Connection": + self.conn.__enter__() + return self + + def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]: + return self.conn.__exit__(exc_type, exc_value, traceback) + + # Proxy through any unknown lookups to the DB conn class. + def __getattr__(self, name): + return getattr(self.conn, name) + + # The type of entry which goes on our after_callbacks and exception_callbacks lists. 
# # Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so @@ -247,6 +306,12 @@ class LoggingTransaction: def close(self) -> None: self.txn.close() + def __enter__(self) -> "LoggingTransaction": + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + class PerformanceCounters: def __init__(self): @@ -395,7 +460,7 @@ class DatabasePool: def new_transaction( self, - conn: Connection, + conn: LoggingDatabaseConnection, desc: str, after_callbacks: List[_CallbackListEntry], exception_callbacks: List[_CallbackListEntry], @@ -436,12 +501,10 @@ class DatabasePool: i = 0 N = 5 while True: - cursor = LoggingTransaction( - conn.cursor(), - name, - self.engine, - after_callbacks, - exception_callbacks, + cursor = conn.cursor( + txn_name=name, + after_callbacks=after_callbacks, + exception_callbacks=exception_callbacks, ) try: r = func(cursor, *args, **kwargs) @@ -574,7 +637,7 @@ class DatabasePool: func, *args, db_autocommit=db_autocommit, - **kwargs + **kwargs, ) for after_callback, after_args, after_kwargs in after_callbacks: @@ -638,7 +701,10 @@ class DatabasePool: if db_autocommit: self.engine.attempt_to_set_autocommit(conn, True) - return func(conn, *args, **kwargs) + db_conn = LoggingDatabaseConnection( + conn, self.engine, "runWithConnection" + ) + return func(db_conn, *args, **kwargs) finally: if db_autocommit: self.engine.attempt_to_set_autocommit(conn, False) @@ -832,6 +898,12 @@ class DatabasePool: attempts = 0 while True: try: + # We can autocommit if we are going to use native upserts + autocommit = ( + self.engine.can_native_upsert + and table not in self._unsafe_to_upsert_tables + ) + return await self.runInteraction( desc, self.simple_upsert_txn, @@ -840,6 +912,7 @@ class DatabasePool: values, insertion_values, lock=lock, + db_autocommit=autocommit, ) except self.engine.module.IntegrityError as e: attempts += 1 @@ -1002,6 +1075,43 @@ class DatabasePool: ) txn.execute(sql, list(allvalues.values())) + async def simple_upsert_many( + self, + table: str, + key_names: Collection[str], + key_values: Collection[Iterable[Any]], + value_names: Collection[str], + value_values: Iterable[Iterable[Any]], + desc: str, + ) -> None: + """ + Upsert, many times. + + Args: + table: The table to upsert into + key_names: The key column names. + key_values: A list of each row's key column values. + value_names: The value column names + value_values: A list of each row's value column values. + Ignored if value_names is empty. 
+ """ + + # We can autocommit if we are going to use native upserts + autocommit = ( + self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables + ) + + return await self.runInteraction( + desc, + self.simple_upsert_many_txn, + table, + key_names, + key_values, + value_names, + value_values, + db_autocommit=autocommit, + ) + def simple_upsert_many_txn( self, txn: LoggingTransaction, @@ -1153,7 +1263,13 @@ class DatabasePool: desc: description of the transaction, for logging and metrics """ return await self.runInteraction( - desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none + desc, + self.simple_select_one_txn, + table, + keyvalues, + retcols, + allow_none, + db_autocommit=True, ) @overload @@ -1204,6 +1320,7 @@ class DatabasePool: keyvalues, retcol, allow_none=allow_none, + db_autocommit=True, ) @overload @@ -1285,7 +1402,12 @@ class DatabasePool: Results in a list """ return await self.runInteraction( - desc, self.simple_select_onecol_txn, table, keyvalues, retcol + desc, + self.simple_select_onecol_txn, + table, + keyvalues, + retcol, + db_autocommit=True, ) async def simple_select_list( @@ -1310,7 +1432,12 @@ class DatabasePool: A list of dictionaries. """ return await self.runInteraction( - desc, self.simple_select_list_txn, table, keyvalues, retcols + desc, + self.simple_select_list_txn, + table, + keyvalues, + retcols, + db_autocommit=True, ) @classmethod @@ -1389,6 +1516,7 @@ class DatabasePool: chunk, keyvalues, retcols, + db_autocommit=True, ) results.extend(rows) @@ -1487,7 +1615,12 @@ class DatabasePool: desc: description of the transaction, for logging and metrics """ await self.runInteraction( - desc, self.simple_update_one_txn, table, keyvalues, updatevalues + desc, + self.simple_update_one_txn, + table, + keyvalues, + updatevalues, + db_autocommit=True, ) @classmethod @@ -1546,7 +1679,9 @@ class DatabasePool: keyvalues: dict of column names and values to select the row with desc: description of the transaction, for logging and metrics """ - await self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues) + await self.runInteraction( + desc, self.simple_delete_one_txn, table, keyvalues, db_autocommit=True, + ) @staticmethod def simple_delete_one_txn( @@ -1585,7 +1720,9 @@ class DatabasePool: Returns: The number of deleted rows. 
""" - return await self.runInteraction(desc, self.simple_delete_txn, table, keyvalues) + return await self.runInteraction( + desc, self.simple_delete_txn, table, keyvalues, db_autocommit=True + ) @staticmethod def simple_delete_txn( @@ -1633,7 +1770,13 @@ class DatabasePool: Number rows deleted """ return await self.runInteraction( - desc, self.simple_delete_many_txn, table, column, iterable, keyvalues + desc, + self.simple_delete_many_txn, + table, + column, + iterable, + keyvalues, + db_autocommit=True, ) @staticmethod @@ -1678,7 +1821,7 @@ class DatabasePool: def get_cache_dict( self, - db_conn: Connection, + db_conn: LoggingDatabaseConnection, table: str, entity_column: str, stream_column: str, @@ -1699,9 +1842,7 @@ class DatabasePool: "limit": limit, } - sql = self.engine.convert_param_style(sql) - - txn = db_conn.cursor() + txn = db_conn.cursor(txn_name="get_cache_dict") txn.execute(sql, (int(max_value),)) cache = {row[0]: int(row[1]) for row in txn} @@ -1801,7 +1942,13 @@ class DatabasePool: """ return await self.runInteraction( - desc, self.simple_search_list_txn, table, term, col, retcols + desc, + self.simple_search_list_txn, + table, + term, + col, + retcols, + db_autocommit=True, ) @classmethod diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index aa5d490624..0c24325011 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -46,7 +46,7 @@ class Databases: db_name = database_config.name engine = create_engine(database_config.config) - with make_conn(database_config, engine) as db_conn: + with make_conn(database_config, engine, "startup") as db_conn: logger.info("[database config %r]: Checking database server", db_name) engine.check_database(db_conn) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 0cb12f4c61..701748f93b 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
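One of the hunks below makes admin user search case-insensitive: user IDs (the `name` column) are already stored lower-case, so only the search term and the displayname comparison need explicit lower-casing. A standalone sketch of the filter assembly; build_user_search_filter is an invented helper that mirrors the new WHERE clause:

def build_user_search_filter(name=None):
    # `name` is stored lower-case, so lower-case the term and the displayname.
    filters, args = [], []
    if name:
        filters.append("(name LIKE ? OR LOWER(displayname) LIKE ?)")
        args.extend(["@%" + name.lower() + "%:%", "%" + name.lower() + "%"])
    return " AND ".join(filters), args

where, params = build_user_search_filter("Alice")
assert params == ["@%alice%:%", "%alice%"]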
@@ -15,9 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import calendar import logging -import time from typing import Any, Dict, List, Optional, Tuple from synapse.api.constants import PresenceState @@ -148,13 +146,9 @@ class DataStore( db_conn, "e2e_cross_signing_keys", "stream_id" ) - self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") - self._pushers_id_gen = StreamIdGenerator( - db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] - ) self._group_updates_id_gen = StreamIdGenerator( db_conn, "local_group_updates", "stream_id" ) @@ -268,9 +262,6 @@ class DataStore( self._stream_order_on_start = self.get_room_max_stream_ordering() self._min_stream_order_on_start = self.get_room_min_stream_ordering() - # Used in _generate_user_daily_visits to keep track of progress - self._last_user_visit_update = self._get_start_of_day() - def get_device_stream_token(self) -> int: return self._device_list_id_gen.get_current_token() @@ -289,7 +280,6 @@ class DataStore( " last_user_sync_ts, status_msg, currently_active FROM presence_stream" " WHERE state != ?" ) - sql = self.database_engine.convert_param_style(sql) txn = db_conn.cursor() txn.execute(sql, (PresenceState.OFFLINE,)) @@ -301,192 +291,6 @@ class DataStore( return [UserPresenceState(**row) for row in rows] - async def count_daily_users(self) -> int: - """ - Counts the number of users who used this homeserver in the last 24 hours. - """ - yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) - return await self.db_pool.runInteraction( - "count_daily_users", self._count_users, yesterday - ) - - async def count_monthly_users(self) -> int: - """ - Counts the number of users who used this homeserver in the last 30 days. - Note this method is intended for phonehome metrics only and is different - from the mau figure in synapse.storage.monthly_active_users which, - amongst other things, includes a 3 day grace period before a user counts. - """ - thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) - return await self.db_pool.runInteraction( - "count_monthly_users", self._count_users, thirty_days_ago - ) - - def _count_users(self, txn, time_from): - """ - Returns number of users seen in the past time_from period - """ - sql = """ - SELECT COALESCE(count(*), 0) FROM ( - SELECT user_id FROM user_ips - WHERE last_seen > ? - GROUP BY user_id - ) u - """ - txn.execute(sql, (time_from,)) - (count,) = txn.fetchone() - return count - - async def count_r30_users(self) -> Dict[str, int]: - """ - Counts the number of 30 day retained users, defined as:- - * Users who have created their accounts more than 30 days ago - * Where last seen at most 30 days ago - * Where account creation and last_seen are > 30 days apart - - Returns: - A mapping of counts globally as well as broken out by platform. 
- """ - - def _count_r30_users(txn): - thirty_days_in_secs = 86400 * 30 - now = int(self._clock.time()) - thirty_days_ago_in_secs = now - thirty_days_in_secs - - sql = """ - SELECT platform, COALESCE(count(*), 0) FROM ( - SELECT - users.name, platform, users.creation_ts * 1000, - MAX(uip.last_seen) - FROM users - INNER JOIN ( - SELECT - user_id, - last_seen, - CASE - WHEN user_agent LIKE '%%Android%%' THEN 'android' - WHEN user_agent LIKE '%%iOS%%' THEN 'ios' - WHEN user_agent LIKE '%%Electron%%' THEN 'electron' - WHEN user_agent LIKE '%%Mozilla%%' THEN 'web' - WHEN user_agent LIKE '%%Gecko%%' THEN 'web' - ELSE 'unknown' - END - AS platform - FROM user_ips - ) uip - ON users.name = uip.user_id - AND users.appservice_id is NULL - AND users.creation_ts < ? - AND uip.last_seen/1000 > ? - AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 - GROUP BY users.name, platform, users.creation_ts - ) u GROUP BY platform - """ - - results = {} - txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) - - for row in txn: - if row[0] == "unknown": - pass - results[row[0]] = row[1] - - sql = """ - SELECT COALESCE(count(*), 0) FROM ( - SELECT users.name, users.creation_ts * 1000, - MAX(uip.last_seen) - FROM users - INNER JOIN ( - SELECT - user_id, - last_seen - FROM user_ips - ) uip - ON users.name = uip.user_id - AND appservice_id is NULL - AND users.creation_ts < ? - AND uip.last_seen/1000 > ? - AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 - GROUP BY users.name, users.creation_ts - ) u - """ - - txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) - - (count,) = txn.fetchone() - results["all"] = count - - return results - - return await self.db_pool.runInteraction("count_r30_users", _count_r30_users) - - def _get_start_of_day(self): - """ - Returns millisecond unixtime for start of UTC day. - """ - now = time.gmtime() - today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) - return today_start * 1000 - - async def generate_user_daily_visits(self) -> None: - """ - Generates daily visit data for use in cohort/ retention analysis - """ - - def _generate_user_daily_visits(txn): - logger.info("Calling _generate_user_daily_visits") - today_start = self._get_start_of_day() - a_day_in_milliseconds = 24 * 60 * 60 * 1000 - now = self.clock.time_msec() - - sql = """ - INSERT INTO user_daily_visits (user_id, device_id, timestamp) - SELECT u.user_id, u.device_id, ? - FROM user_ips AS u - LEFT JOIN ( - SELECT user_id, device_id, timestamp FROM user_daily_visits - WHERE timestamp = ? - ) udv - ON u.user_id = udv.user_id AND u.device_id=udv.device_id - INNER JOIN users ON users.name=u.user_id - WHERE last_seen > ? AND last_seen <= ? - AND udv.timestamp IS NULL AND users.is_guest=0 - AND users.appservice_id IS NULL - GROUP BY u.user_id, u.device_id - """ - - # This means that the day has rolled over but there could still - # be entries from the previous day. There is an edge case - # where if the user logs in at 23:59 and overwrites their - # last_seen at 00:01 then they will not be counted in the - # previous day's stats - it is important that the query is run - # often to minimise this case. 
- if today_start > self._last_user_visit_update: - yesterday_start = today_start - a_day_in_milliseconds - txn.execute( - sql, - ( - yesterday_start, - yesterday_start, - self._last_user_visit_update, - today_start, - ), - ) - self._last_user_visit_update = today_start - - txn.execute( - sql, (today_start, today_start, self._last_user_visit_update, now) - ) - # Update _last_user_visit_update to now. The reason to do this - # rather just clamping to the beginning of the day is to limit - # the size of the join - meaning that the query can be run more - # frequently - self._last_user_visit_update = now - - await self.db_pool.runInteraction( - "generate_user_daily_visits", _generate_user_daily_visits - ) - async def get_users(self) -> List[Dict[str, Any]]: """Function to retrieve a list of users in users table. @@ -535,12 +339,13 @@ class DataStore( filters = [] args = [self.hs.config.server_name] + # `name` is in database already in lower case if name: - filters.append("(name LIKE ? OR displayname LIKE ?)") - args.extend(["@%" + name + "%:%", "%" + name + "%"]) + filters.append("(name LIKE ? OR LOWER(displayname) LIKE ?)") + args.extend(["@%" + name.lower() + "%:%", "%" + name.lower() + "%"]) elif user_id: filters.append("name LIKE ?") - args.extend(["%" + user_id + "%"]) + args.extend(["%" + user_id.lower() + "%"]) if not guests: filters.append("is_guest = 0") diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index ef81d73573..49ee23470d 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
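The hunks below switch to the AccountDataTypes.IGNORED_USER_LIST constant and guard the ignored-user lookup with a TypeError handler, because account data is supplied by clients and the ignored_users field may have the wrong type. A standalone restatement of that guarded lookup; the is_ignored helper and the example.org IDs are placeholders, and the well-formed content follows the m.ignored_user_list shape:

def is_ignored(ignored_account_data, ignored_user_id):
    # Account data comes from clients, so "ignored_users" may be missing or
    # have the wrong type; treat anything malformed as "not ignored".
    if not ignored_account_data:
        return False
    try:
        return ignored_user_id in ignored_account_data.get("ignored_users", {})
    except TypeError:
        return False

well_formed = {"ignored_users": {"@spammer:example.org": {}}}
malformed = {"ignored_users": None}  # wrong type supplied by a client

assert is_ignored(well_formed, "@spammer:example.org")
assert not is_ignored(malformed, "@spammer:example.org")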
@@ -18,6 +18,7 @@ import abc import logging from typing import Dict, List, Optional, Tuple +from synapse.api.constants import AccountDataTypes from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator @@ -291,14 +292,18 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext ) -> bool: ignored_account_data = await self.get_global_account_data_by_type_for_user( - "m.ignored_user_list", + AccountDataTypes.IGNORED_USER_LIST, ignorer_user_id, on_invalidate=cache_context.invalidate, ) if not ignored_account_data: return False - return ignored_user_id in ignored_account_data.get("ignored_users", {}) + try: + return ignored_user_id in ignored_account_data.get("ignored_users", {}) + except TypeError: + # The type of the ignored_users field is invalid. + return False class AccountDataStore(AccountDataWorkerStore): diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 85f6b1e3fd..e550cbc866 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
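The hunks below let application service transactions carry ephemeral events (which are never persisted or resent on retry) and add per-appservice stream positions for the "read_receipt" and "presence" streams. A hedged sketch of how a sender loop might combine these new store methods; send_new_read_receipts and fetch_receipts_since are invented for illustration, while the three store calls are the methods added below:

async def send_new_read_receipts(store, service):
    # Read the last acknowledged receipt position for this appservice.
    from_pos = await store.get_type_stream_id_for_appservice(service, "read_receipt")
    receipts, new_pos = await fetch_receipts_since(service, from_pos)  # hypothetical
    if receipts:
        # Ephemeral events ride along in the transaction but are not persisted.
        await store.create_appservice_txn(service, events=[], ephemeral=receipts)
        await store.set_type_stream_id_for_appservice(service, "read_receipt", new_pos)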
@@ -15,18 +15,31 @@ # limitations under the License. import logging import re +from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple -from synapse.appservice import AppServiceTransaction +from synapse.appservice import ( + ApplicationService, + ApplicationServiceState, + AppServiceTransaction, +) from synapse.config.appservice import load_appservices +from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.types import Connection +from synapse.types import JsonDict from synapse.util import json_encoder +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) -def _make_exclusive_regex(services_cache): +def _make_exclusive_regex( + services_cache: List[ApplicationService], +) -> Optional[Pattern]: # We precompile a regex constructed from all the regexes that the AS's # have registered for exclusive users. exclusive_user_regexes = [ @@ -36,17 +49,19 @@ def _make_exclusive_regex(services_cache): ] if exclusive_user_regexes: exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes) - exclusive_user_regex = re.compile(exclusive_user_regex) + exclusive_user_pattern = re.compile( + exclusive_user_regex + ) # type: Optional[Pattern] else: # We handle this case specially otherwise the constructed regex # will always match - exclusive_user_regex = None + exclusive_user_pattern = None - return exclusive_user_regex + return exclusive_user_pattern class ApplicationServiceWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): self.services_cache = load_appservices( hs.hostname, hs.config.app_service_config_files ) @@ -57,7 +72,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore): def get_app_services(self): return self.services_cache - def get_if_app_services_interested_in_user(self, user_id): + def get_if_app_services_interested_in_user(self, user_id: str) -> bool: """Check if the user is one associated with an app service (exclusively) """ if self.exclusive_user_regex: @@ -65,7 +80,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore): else: return False - def get_app_service_by_user_id(self, user_id): + def get_app_service_by_user_id(self, user_id: str) -> Optional[ApplicationService]: """Retrieve an application service from their user ID. All application services have associated with them a particular user ID. @@ -74,35 +89,35 @@ class ApplicationServiceWorkerStore(SQLBaseStore): a user ID to an application service. Args: - user_id(str): The user ID to see if it is an application service. + user_id: The user ID to see if it is an application service. Returns: - synapse.appservice.ApplicationService or None. + The application service or None. """ for service in self.services_cache: if service.sender == user_id: return service return None - def get_app_service_by_token(self, token): + def get_app_service_by_token(self, token: str) -> Optional[ApplicationService]: """Get the application service with the given appservice token. Args: - token (str): The application service token. + token: The application service token. Returns: - synapse.appservice.ApplicationService or None. + The application service or None. 
""" for service in self.services_cache: if service.token == token: return service return None - def get_app_service_by_id(self, as_id): + def get_app_service_by_id(self, as_id: str) -> Optional[ApplicationService]: """Get the application service with the given appservice ID. Args: - as_id (str): The application service ID. + as_id: The application service ID. Returns: - synapse.appservice.ApplicationService or None. + The application service or None. """ for service in self.services_cache: if service.id == as_id: @@ -121,11 +136,13 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore): class ApplicationServiceTransactionWorkerStore( ApplicationServiceWorkerStore, EventsWorkerStore ): - async def get_appservices_by_state(self, state): + async def get_appservices_by_state( + self, state: ApplicationServiceState + ) -> List[ApplicationService]: """Get a list of application services based on their state. Args: - state(ApplicationServiceState): The state to filter on. + state: The state to filter on. Returns: A list of ApplicationServices, which may be empty. """ @@ -142,13 +159,15 @@ class ApplicationServiceTransactionWorkerStore( services.append(service) return services - async def get_appservice_state(self, service): + async def get_appservice_state( + self, service: ApplicationService + ) -> Optional[ApplicationServiceState]: """Get the application service state. Args: - service(ApplicationService): The service whose state to set. + service: The service whose state to set. Returns: - An ApplicationServiceState. + An ApplicationServiceState or none. """ result = await self.db_pool.simple_select_one( "application_services_state", @@ -161,26 +180,36 @@ class ApplicationServiceTransactionWorkerStore( return result.get("state") return None - async def set_appservice_state(self, service, state) -> None: + async def set_appservice_state( + self, service: ApplicationService, state: ApplicationServiceState + ) -> None: """Set the application service state. Args: - service(ApplicationService): The service whose state to set. - state(ApplicationServiceState): The connectivity state to apply. + service: The service whose state to set. + state: The connectivity state to apply. """ await self.db_pool.simple_upsert( "application_services_state", {"as_id": service.id}, {"state": state} ) - async def create_appservice_txn(self, service, events): + async def create_appservice_txn( + self, + service: ApplicationService, + events: List[EventBase], + ephemeral: List[JsonDict], + ) -> AppServiceTransaction: """Atomically creates a new transaction for this application service - with the given list of events. + with the given list of events. Ephemeral events are NOT persisted to the + database and are not resent if a transaction is retried. Args: - service(ApplicationService): The service who the transaction is for. - events(list<Event>): A list of events to put in the transaction. + service: The service who the transaction is for. + events: A list of persistent events to put in the transaction. + ephemeral: A list of ephemeral events to put in the transaction. + Returns: - AppServiceTransaction: A new transaction. + A new transaction. 
""" def _create_appservice_txn(txn): @@ -207,19 +236,22 @@ class ApplicationServiceTransactionWorkerStore( "VALUES(?,?,?)", (service.id, new_txn_id, event_ids), ) - return AppServiceTransaction(service=service, id=new_txn_id, events=events) + return AppServiceTransaction( + service=service, id=new_txn_id, events=events, ephemeral=ephemeral + ) return await self.db_pool.runInteraction( "create_appservice_txn", _create_appservice_txn ) - async def complete_appservice_txn(self, txn_id, service) -> None: + async def complete_appservice_txn( + self, txn_id: int, service: ApplicationService + ) -> None: """Completes an application service transaction. Args: - txn_id(str): The transaction ID being completed. - service(ApplicationService): The application service which was sent - this transaction. + txn_id: The transaction ID being completed. + service: The application service which was sent this transaction. """ txn_id = int(txn_id) @@ -259,12 +291,13 @@ class ApplicationServiceTransactionWorkerStore( "complete_appservice_txn", _complete_appservice_txn ) - async def get_oldest_unsent_txn(self, service): - """Get the oldest transaction which has not been sent for this - service. + async def get_oldest_unsent_txn( + self, service: ApplicationService + ) -> Optional[AppServiceTransaction]: + """Get the oldest transaction which has not been sent for this service. Args: - service(ApplicationService): The app service to get the oldest txn. + service: The app service to get the oldest txn. Returns: An AppServiceTransaction or None. """ @@ -296,9 +329,11 @@ class ApplicationServiceTransactionWorkerStore( events = await self.get_events_as_list(event_ids) - return AppServiceTransaction(service=service, id=entry["txn_id"], events=events) + return AppServiceTransaction( + service=service, id=entry["txn_id"], events=events, ephemeral=[] + ) - def _get_last_txn(self, txn, service_id): + def _get_last_txn(self, txn, service_id: Optional[str]) -> int: txn.execute( "SELECT last_txn FROM application_services_state WHERE as_id=?", (service_id,), @@ -309,7 +344,7 @@ class ApplicationServiceTransactionWorkerStore( else: return int(last_txn_id[0]) # select 'last_txn' col - async def set_appservice_last_pos(self, pos) -> None: + async def set_appservice_last_pos(self, pos: int) -> None: def set_appservice_last_pos_txn(txn): txn.execute( "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) @@ -319,8 +354,10 @@ class ApplicationServiceTransactionWorkerStore( "set_appservice_last_pos", set_appservice_last_pos_txn ) - async def get_new_events_for_appservice(self, current_id, limit): - """Get all new evnets""" + async def get_new_events_for_appservice( + self, current_id: int, limit: int + ) -> Tuple[int, List[EventBase]]: + """Get all new events for an appservice""" def get_new_events_for_appservice_txn(txn): sql = ( @@ -351,6 +388,54 @@ class ApplicationServiceTransactionWorkerStore( return upper_bound, events + async def get_type_stream_id_for_appservice( + self, service: ApplicationService, type: str + ) -> int: + if type not in ("read_receipt", "presence"): + raise ValueError( + "Expected type to be a valid application stream id type, got %s" + % (type,) + ) + + def get_type_stream_id_for_appservice_txn(txn): + stream_id_type = "%s_stream_id" % type + txn.execute( + # We do NOT want to escape `stream_id_type`. + "SELECT %s FROM application_services_state WHERE as_id=?" 
+ % stream_id_type, + (service.id,), + ) + last_stream_id = txn.fetchone() + if last_stream_id is None or last_stream_id[0] is None: # no row exists + return 0 + else: + return int(last_stream_id[0]) + + return await self.db_pool.runInteraction( + "get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn + ) + + async def set_type_stream_id_for_appservice( + self, service: ApplicationService, type: str, pos: Optional[int] + ) -> None: + if type not in ("read_receipt", "presence"): + raise ValueError( + "Expected type to be a valid application stream id type, got %s" + % (type,) + ) + + def set_type_stream_id_for_appservice_txn(txn): + stream_id_type = "%s_stream_id" % type + txn.execute( + "UPDATE application_services_state SET %s = ? WHERE as_id=?" + % stream_id_type, + (pos, service.id), + ) + + await self.db_pool.runInteraction( + "set_type_stream_id_for_appservice", set_type_stream_id_for_appservice_txn + ) + class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore): # This is currently empty due to there not being any AS storage functions diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index f211ddbaf8..3e26d5ba87 100644 --- a/synapse/storage/databases/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py
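The censor_events.py changes below replace a hand-rolled run_as_background_process closure with the @wrap_as_background_process decorator and only schedule the periodic censor job when hs.config.run_background_tasks is set. As a rough, self-contained sketch of that decorator pattern (this is not Synapse's implementation; the helper name, the use of asyncio and the task naming are assumptions made for illustration):

import asyncio
import functools

def wrap_as_background_process_sketch(desc):
    """Decorator sketch: run the wrapped coroutine as a named background task."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # The real helper also sets up logging context and metrics; this
            # sketch only names the task and schedules it.
            return asyncio.create_task(func(*args, **kwargs), name=desc)
        return wrapper
    return decorator

class CensorSketch:
    censored = 0

    @wrap_as_background_process_sketch("_censor_redactions")
    async def _censor_redactions(self):
        self.censored += 1  # stand-in for pruning old redacted event JSON

async def main():
    store = CensorSketch()
    task = store._censor_redactions()  # callers get a task back; no need to await inline
    await task
    print(store.censored)  # -> 1

asyncio.run(main())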
@@ -17,12 +17,12 @@ import logging from typing import TYPE_CHECKING from synapse.events.utils import prune_event_dict -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore -from synapse.storage.databases.main.events import encode_json from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.util import json_encoder if TYPE_CHECKING: from synapse.server import HomeServer @@ -35,14 +35,13 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) - def _censor_redactions(): - return run_as_background_process( - "_censor_redactions", self._censor_redactions - ) - - if self.hs.config.redaction_retention_period is not None: - hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000) + if ( + hs.config.run_background_tasks + and self.hs.config.redaction_retention_period is not None + ): + hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000) + @wrap_as_background_process("_censor_redactions") async def _censor_redactions(self): """Censors all redactions older than the configured period that haven't been censored yet. @@ -105,7 +104,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase and original_event.internal_metadata.is_redacted() ): # Redaction was allowed - pruned_json = encode_json( + pruned_json = json_encoder.encode( prune_event_dict( original_event.room_version, original_event.get_dict() ) @@ -171,7 +170,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase return # Prune the event's dict then convert it to JSON. - pruned_json = encode_json( + pruned_json = json_encoder.encode( prune_event_dict(event.room_version, event.get_dict()) ) diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 239c7a949c..e96a8b3f43 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py
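The client_ips.py changes below move the user_ips pruning job into a new ClientIpWorkerStore while keeping the bounded DELETE trick described in the comments: the sub-SELECT finds the newest last_seen value that has at most 5000 older-or-equal rows, so each pass deletes a limited batch even if the table has a large backlog. A minimal sqlite3 sketch of that query (the table contents here are invented for illustration):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE user_ips (user_id TEXT, ip TEXT, last_seen INTEGER)")
conn.executemany(
    "INSERT INTO user_ips VALUES (?, ?, ?)",
    [("@u:example.org", "10.0.0.%d" % i, i) for i in range(10)],
)

cutoff = 7  # everything with last_seen <= 7 is eligible for pruning
prune_sql = """
    DELETE FROM user_ips
    WHERE last_seen <= (
        SELECT COALESCE(MAX(last_seen), -1)
        FROM (
            SELECT last_seen FROM user_ips
            WHERE last_seen <= ?
            ORDER BY last_seen ASC
            LIMIT 5000
        ) AS u
    )
"""
conn.execute(prune_sql, (cutoff,))
# Rows 0..7 are gone; rows 8 and 9 survive because they are newer than the cutoff.
print(conn.execute("SELECT COUNT(*) FROM user_ips").fetchone()[0])  # -> 2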
@@ -14,12 +14,13 @@ # limitations under the License. import logging -from typing import Dict, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_tuple_comparison_clause -from synapse.util.caches.descriptors import Cache +from synapse.types import UserID +from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -351,16 +352,70 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): return updated -class ClientIpStore(ClientIpBackgroundUpdateStore): +class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): + super().__init__(database, db_conn, hs) + + self.user_ips_max_age = hs.config.user_ips_max_age + + if hs.config.run_background_tasks and self.user_ips_max_age: + self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) + + @wrap_as_background_process("prune_old_user_ips") + async def _prune_old_user_ips(self): + """Removes entries in user IPs older than the configured period. + """ + + if self.user_ips_max_age is None: + # Nothing to do + return + + if not await self.db_pool.updates.has_completed_background_update( + "devices_last_seen" + ): + # Only start pruning if we have finished populating the devices + # last seen info. + return + + # We do a slightly funky SQL delete to ensure we don't try and delete + # too much at once (as the table may be very large from before we + # started pruning). + # + # This works by finding the max last_seen that is less than the given + # time, but has no more than N rows before it, deleting all rows with + # a lesser last_seen time. (We COALESCE so that the sub-SELECT always + # returns exactly one row). + sql = """ + DELETE FROM user_ips + WHERE last_seen <= ( + SELECT COALESCE(MAX(last_seen), -1) + FROM ( + SELECT last_seen FROM user_ips + WHERE last_seen <= ? 
+ ORDER BY last_seen ASC + LIMIT 5000 + ) AS u + ) + """ - self.client_ip_last_seen = Cache( - name="client_ip_last_seen", keylen=4, max_entries=50000 + timestamp = self.clock.time_msec() - self.user_ips_max_age + + def _prune_old_user_ips_txn(txn): + txn.execute(sql, (timestamp,)) + + await self.db_pool.runInteraction( + "_prune_old_user_ips", _prune_old_user_ips_txn ) - super().__init__(database, db_conn, hs) - self.user_ips_max_age = hs.config.user_ips_max_age +class ClientIpStore(ClientIpWorkerStore): + def __init__(self, database: DatabasePool, db_conn, hs): + + self.client_ip_last_seen = LruCache( + cache_name="client_ip_last_seen", keylen=4, max_size=50000 + ) + + super().__init__(database, db_conn, hs) # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen) self._batch_row_update = {} @@ -372,9 +427,6 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): "before", "shutdown", self._update_client_ips_batch ) - if self.user_ips_max_age: - self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) - async def insert_client_ip( self, user_id, access_token, ip, user_agent, device_id, now=None ): @@ -391,7 +443,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: return - self.client_ip_last_seen.prefill(key, now) + self.client_ip_last_seen.set(key, now) self._batch_row_update[key] = (user_agent, device_id, now) @@ -495,7 +547,9 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): } return ret - async def get_user_ip_and_agents(self, user): + async def get_user_ip_and_agents( + self, user: UserID + ) -> List[Dict[str, Union[str, int]]]: user_id = user.to_string() results = {} @@ -525,49 +579,3 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): } for (access_token, ip), (user_agent, last_seen) in results.items() ] - - @wrap_as_background_process("prune_old_user_ips") - async def _prune_old_user_ips(self): - """Removes entries in user IPs older than the configured period. - """ - - if self.user_ips_max_age is None: - # Nothing to do - return - - if not await self.db_pool.updates.has_completed_background_update( - "devices_last_seen" - ): - # Only start pruning if we have finished populating the devices - # last seen info. - return - - # We do a slightly funky SQL delete to ensure we don't try and delete - # too much at once (as the table may be very large from before we - # started pruning). - # - # This works by finding the max last_seen that is less than the given - # time, but has no more than N rows before it, deleting all rows with - # a lesser last_seen time. (We COALESCE so that the sub-SELECT always - # returns exactly one row). - sql = """ - DELETE FROM user_ips - WHERE last_seen <= ( - SELECT COALESCE(MAX(last_seen), -1) - FROM ( - SELECT last_seen FROM user_ips - WHERE last_seen <= ? - ORDER BY last_seen ASC - LIMIT 5000 - ) AS u - ) - """ - - timestamp = self.clock.time_msec() - self.user_ips_max_age - - def _prune_old_user_ips_txn(txn): - txn.execute(sql, (timestamp,)) - - await self.db_pool.runInteraction( - "_prune_old_user_ips", _prune_old_user_ips_txn - ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index fdf394c612..9097677648 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py
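The devices.py changes below swap the deprecated Cache helper for LruCache (prefill(...) becomes set(...), max_entries becomes max_size) and move the outbound-poke pruning behind the run_background_tasks flag. As a rough illustration of the underlying data structure only (this is not Synapse's LruCache), a minimal LRU map with the get/set surface used in the diff:

from collections import OrderedDict

class TinyLruCache:
    def __init__(self, max_size: int):
        self.max_size = max_size
        self._data = OrderedDict()

    def get(self, key, default=None):
        if key in self._data:
            self._data.move_to_end(key)  # mark as most recently used
            return self._data[key]
        return default

    def set(self, key, value):
        self._data[key] = value
        self._data.move_to_end(key)
        if len(self._data) > self.max_size:
            self._data.popitem(last=False)  # evict the least recently used entry

device_id_exists = TinyLruCache(max_size=10000)
device_id_exists.set(("@alice:example.org", "DEVICE1"), True)
assert device_id_exists.get(("@alice:example.org", "DEVICE1")) is True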
@@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd # Copyright 2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019,2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,7 +25,7 @@ from synapse.logging.opentracing import ( trace, whitelisted_homeserver, ) -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, @@ -33,8 +33,9 @@ from synapse.storage.database import ( make_tuple_comparison_clause, ) from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key -from synapse.util import json_encoder -from synapse.util.caches.descriptors import Cache, cached, cachedList +from synapse.util import json_decoder, json_encoder +from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr @@ -48,6 +49,46 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" class DeviceWorkerStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super().__init__(database, db_conn, hs) + + if hs.config.run_background_tasks: + self._clock.looping_call( + self._prune_old_outbound_device_pokes, 60 * 60 * 1000 + ) + + async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int: + """Retrieve number of all devices of given users. + Only returns number of devices that are not marked as hidden. + + Args: + user_ids: The IDs of the users which owns devices + Returns: + Number of devices of this users. + """ + + def count_devices_by_users_txn(txn, user_ids): + sql = """ + SELECT count(*) + FROM devices + WHERE + hidden = '0' AND + """ + + clause, args = make_in_list_sql_clause( + txn.database_engine, "user_id", user_ids + ) + + txn.execute(sql + clause, args) + return txn.fetchone()[0] + + if not user_ids: + return 0 + + return await self.db_pool.runInteraction( + "count_devices_by_users", count_devices_by_users_txn, user_ids + ) + async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]: """Retrieve a device. Only returns devices that are not marked as hidden. @@ -698,6 +739,172 @@ class DeviceWorkerStore(SQLBaseStore): _mark_remote_user_device_list_as_unsubscribed_txn, ) + async def get_dehydrated_device( + self, user_id: str + ) -> Optional[Tuple[str, JsonDict]]: + """Retrieve the information for a dehydrated device. 
+ + Args: + user_id: the user whose dehydrated device we are looking for + Returns: + a tuple whose first item is the device ID, and the second item is + the dehydrated device information + """ + # FIXME: make sure device ID still exists in devices table + row = await self.db_pool.simple_select_one( + table="dehydrated_devices", + keyvalues={"user_id": user_id}, + retcols=["device_id", "device_data"], + allow_none=True, + ) + return ( + (row["device_id"], json_decoder.decode(row["device_data"])) if row else None + ) + + def _store_dehydrated_device_txn( + self, txn, user_id: str, device_id: str, device_data: str + ) -> Optional[str]: + old_device_id = self.db_pool.simple_select_one_onecol_txn( + txn, + table="dehydrated_devices", + keyvalues={"user_id": user_id}, + retcol="device_id", + allow_none=True, + ) + self.db_pool.simple_upsert_txn( + txn, + table="dehydrated_devices", + keyvalues={"user_id": user_id}, + values={"device_id": device_id, "device_data": device_data}, + ) + return old_device_id + + async def store_dehydrated_device( + self, user_id: str, device_id: str, device_data: JsonDict + ) -> Optional[str]: + """Store a dehydrated device for a user. + + Args: + user_id: the user that we are storing the device for + device_id: the ID of the dehydrated device + device_data: the dehydrated device information + Returns: + device id of the user's previous dehydrated device, if any + """ + return await self.db_pool.runInteraction( + "store_dehydrated_device_txn", + self._store_dehydrated_device_txn, + user_id, + device_id, + json_encoder.encode(device_data), + ) + + async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool: + """Remove a dehydrated device. + + Args: + user_id: the user that the dehydrated device belongs to + device_id: the ID of the dehydrated device + """ + count = await self.db_pool.simple_delete( + "dehydrated_devices", + {"user_id": user_id, "device_id": device_id}, + desc="remove_dehydrated_device", + ) + return count >= 1 + + @wrap_as_background_process("prune_old_outbound_device_pokes") + async def _prune_old_outbound_device_pokes( + self, prune_age: int = 24 * 60 * 60 * 1000 + ) -> None: + """Delete old entries out of the device_lists_outbound_pokes to ensure + that we don't fill up due to dead servers. + + Normally, we try to send device updates as a delta since a previous known point: + this is done by setting the prev_id in the m.device_list_update EDU. However, + for that to work, we have to have a complete record of each change to + each device, which can add up to quite a lot of data. + + An alternative mechanism is that, if the remote server sees that it has missed + an entry in the stream_id sequence for a given user, it will request a full + list of that user's devices. Hence, we can reduce the amount of data we have to + store (and transmit in some future transaction), by clearing almost everything + for a given destination out of the database, and having the remote server + resync. + + All we need to do is make sure we keep at least one row for each + (user, destination) pair, to remind us to send a m.device_list_update EDU for + that user when the destination comes back. It doesn't matter which device + we keep. + """ + yesterday = self._clock.time_msec() - prune_age + + def _prune_txn(txn): + # look for (user, destination) pairs which have an update older than + # the cutoff. + # + # For each pair, we also need to know the most recent stream_id, and + # an arbitrary device_id at that stream_id. 
+ select_sql = """ + SELECT + dlop1.destination, + dlop1.user_id, + MAX(dlop1.stream_id) AS stream_id, + (SELECT MIN(dlop2.device_id) AS device_id FROM + device_lists_outbound_pokes dlop2 + WHERE dlop2.destination = dlop1.destination AND + dlop2.user_id=dlop1.user_id AND + dlop2.stream_id=MAX(dlop1.stream_id) + ) + FROM device_lists_outbound_pokes dlop1 + GROUP BY destination, user_id + HAVING min(ts) < ? AND count(*) > 1 + """ + + txn.execute(select_sql, (yesterday,)) + rows = txn.fetchall() + + if not rows: + return + + logger.info( + "Pruning old outbound device list updates for %i users/destinations: %s", + len(rows), + shortstr((row[0], row[1]) for row in rows), + ) + + # we want to keep the update with the highest stream_id for each user. + # + # there might be more than one update (with different device_ids) with the + # same stream_id, so we also delete all but one rows with the max stream id. + delete_sql = """ + DELETE FROM device_lists_outbound_pokes + WHERE destination = ? AND user_id = ? AND ( + stream_id < ? OR + (stream_id = ? AND device_id != ?) + ) + """ + count = 0 + for (destination, user_id, stream_id, device_id) in rows: + txn.execute( + delete_sql, (destination, user_id, stream_id, stream_id, device_id) + ) + count += txn.rowcount + + # Since we've deleted unsent deltas, we need to remove the entry + # of last successful sent so that the prev_ids are correctly set. + sql = """ + DELETE FROM device_lists_outbound_last_success + WHERE destination = ? AND user_id = ? + """ + txn.executemany(sql, ((row[0], row[1]) for row in rows)) + + logger.info("Pruned %d device list outbound pokes", count) + + await self.db_pool.runInteraction( + "_prune_old_outbound_device_pokes", _prune_txn, + ) + class DeviceBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): @@ -830,14 +1037,12 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Map of (user_id, device_id) -> bool. If there is an entry that implies # the device exists. - self.device_id_exists_cache = Cache( - name="device_id_exists", keylen=2, max_entries=10000 + self.device_id_exists_cache = LruCache( + cache_name="device_id_exists", keylen=2, max_size=10000 ) - self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000) - async def store_device( - self, user_id: str, device_id: str, initial_device_display_name: str + self, user_id: str, device_id: str, initial_device_display_name: Optional[str] ) -> bool: """Ensure the given device is known; add it to the store if not @@ -879,7 +1084,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) if hidden: raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN) - self.device_id_exists_cache.prefill(key, True) + self.device_id_exists_cache.set(key, True) return inserted except StoreError: raise @@ -955,7 +1160,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) async def update_remote_device_list_cache_entry( - self, user_id: str, device_id: str, content: JsonDict, stream_id: int + self, user_id: str, device_id: str, content: JsonDict, stream_id: str ) -> None: """Updates a single device in the cache of a remote user's devicelist. 
@@ -983,7 +1188,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): user_id: str, device_id: str, content: JsonDict, - stream_id: int, + stream_id: str, ) -> None: if content.get("deleted"): self.db_pool.simple_delete_txn( @@ -1193,95 +1398,3 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): for device_id in device_ids ], ) - - def _prune_old_outbound_device_pokes(self, prune_age: int = 24 * 60 * 60 * 1000): - """Delete old entries out of the device_lists_outbound_pokes to ensure - that we don't fill up due to dead servers. - - Normally, we try to send device updates as a delta since a previous known point: - this is done by setting the prev_id in the m.device_list_update EDU. However, - for that to work, we have to have a complete record of each change to - each device, which can add up to quite a lot of data. - - An alternative mechanism is that, if the remote server sees that it has missed - an entry in the stream_id sequence for a given user, it will request a full - list of that user's devices. Hence, we can reduce the amount of data we have to - store (and transmit in some future transaction), by clearing almost everything - for a given destination out of the database, and having the remote server - resync. - - All we need to do is make sure we keep at least one row for each - (user, destination) pair, to remind us to send a m.device_list_update EDU for - that user when the destination comes back. It doesn't matter which device - we keep. - """ - yesterday = self._clock.time_msec() - prune_age - - def _prune_txn(txn): - # look for (user, destination) pairs which have an update older than - # the cutoff. - # - # For each pair, we also need to know the most recent stream_id, and - # an arbitrary device_id at that stream_id. - select_sql = """ - SELECT - dlop1.destination, - dlop1.user_id, - MAX(dlop1.stream_id) AS stream_id, - (SELECT MIN(dlop2.device_id) AS device_id FROM - device_lists_outbound_pokes dlop2 - WHERE dlop2.destination = dlop1.destination AND - dlop2.user_id=dlop1.user_id AND - dlop2.stream_id=MAX(dlop1.stream_id) - ) - FROM device_lists_outbound_pokes dlop1 - GROUP BY destination, user_id - HAVING min(ts) < ? AND count(*) > 1 - """ - - txn.execute(select_sql, (yesterday,)) - rows = txn.fetchall() - - if not rows: - return - - logger.info( - "Pruning old outbound device list updates for %i users/destinations: %s", - len(rows), - shortstr((row[0], row[1]) for row in rows), - ) - - # we want to keep the update with the highest stream_id for each user. - # - # there might be more than one update (with different device_ids) with the - # same stream_id, so we also delete all but one rows with the max stream id. - delete_sql = """ - DELETE FROM device_lists_outbound_pokes - WHERE destination = ? AND user_id = ? AND ( - stream_id < ? OR - (stream_id = ? AND device_id != ?) - ) - """ - count = 0 - for (destination, user_id, stream_id, device_id) in rows: - txn.execute( - delete_sql, (destination, user_id, stream_id, stream_id, device_id) - ) - count += txn.rowcount - - # Since we've deleted unsent deltas, we need to remove the entry - # of last successful sent so that the prev_ids are correctly set. - sql = """ - DELETE FROM device_lists_outbound_last_success - WHERE destination = ? AND user_id = ? 
- """ - txn.executemany(sql, ((row[0], row[1]) for row in rows)) - - logger.info("Pruned %d device list outbound pokes", count) - - return run_as_background_process( - "prune_old_outbound_device_pokes", - self.db_pool.runInteraction, - "_prune_old_outbound_device_pokes", - _prune_txn, - ) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 22e1ed15d0..4d1b92d1aa 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py
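The end_to_end_keys.py changes below add e2e fallback keys: when no one-time key is available to claim, the stored fallback key is returned instead and marked as used rather than deleted. A plain-dict sketch of that claim decision (key IDs and key material are invented for illustration):

one_time_keys = {
    ("@bob:example.org", "DEV", "signed_curve25519"): ("AAAAHQ", "{...otk...}")
}
fallback_keys = {
    ("@bob:example.org", "DEV", "signed_curve25519"): {
        "key_id": "AAAAGj", "key_json": "{...fallback...}", "used": False,
    }
}

def claim_key(user_id, device_id, algorithm):
    otk = one_time_keys.pop((user_id, device_id, algorithm), None)
    if otk is not None:
        key_id, key_json = otk
        return {f"{algorithm}:{key_id}": key_json}  # one-time keys are consumed
    fallback = fallback_keys.get((user_id, device_id, algorithm))
    if fallback is not None:
        fallback["used"] = True  # fallback keys are reused, only flagged as used
        return {f"{algorithm}:{fallback['key_id']}": fallback["key_json"]}
    return {}

print(claim_key("@bob:example.org", "DEV", "signed_curve25519"))  # serves the one-time key
print(claim_key("@bob:example.org", "DEV", "signed_curve25519"))  # falls back to the fallback key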
@@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019,2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ from twisted.enterprise.adbapi import Connection from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import make_in_list_sql_clause +from synapse.storage.database import DatabasePool, make_in_list_sql_clause from synapse.storage.types import Cursor from synapse.types import JsonDict from synapse.util import json_encoder @@ -33,6 +33,7 @@ from synapse.util.iterutils import batch_iter if TYPE_CHECKING: from synapse.handlers.e2e_keys import SignatureListItem + from synapse.server import HomeServer @attr.s(slots=True) @@ -47,7 +48,20 @@ class DeviceKeyLookupResult: keys = attr.ib(type=Optional[JsonDict]) -class EndToEndKeyWorkerStore(SQLBaseStore): +class EndToEndKeyBackgroundStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_index_update( + "e2e_cross_signing_keys_idx", + index_name="e2e_cross_signing_keys_stream_idx", + table="e2e_cross_signing_keys", + columns=["stream_id"], + unique=True, + ) + + +class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): async def get_e2e_device_keys_for_federation_query( self, user_id: str ) -> Tuple[int, List[JsonDict]]: @@ -367,6 +381,61 @@ class EndToEndKeyWorkerStore(SQLBaseStore): "count_e2e_one_time_keys", _count_e2e_one_time_keys ) + async def set_e2e_fallback_keys( + self, user_id: str, device_id: str, fallback_keys: JsonDict + ) -> None: + """Set the user's e2e fallback keys. + + Args: + user_id: the user whose keys are being set + device_id: the device whose keys are being set + fallback_keys: the keys to set. This is a map from key ID (which is + of the form "algorithm:id") to key data. + """ + # fallback_keys will usually only have one item in it, so using a for + # loop (as opposed to calling simple_upsert_many_txn) won't be too bad + # FIXME: make sure that only one key per algorithm is uploaded + for key_id, fallback_key in fallback_keys.items(): + algorithm, key_id = key_id.split(":", 1) + await self.db_pool.simple_upsert( + "e2e_fallback_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + }, + values={ + "key_id": key_id, + "key_json": json_encoder.encode(fallback_key), + "used": False, + }, + desc="set_e2e_fallback_key", + ) + + await self.invalidate_cache_and_stream( + "get_e2e_unused_fallback_key_types", (user_id, device_id) + ) + + @cached(max_entries=10000) + async def get_e2e_unused_fallback_key_types( + self, user_id: str, device_id: str + ) -> List[str]: + """Returns the fallback key types that have an unused key. 
+ + Args: + user_id: the user whose keys are being queried + device_id: the device whose keys are being queried + + Returns: + a list of key types + """ + return await self.db_pool.simple_select_onecol( + "e2e_fallback_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id, "used": False}, + retcol="algorithm", + desc="get_e2e_unused_fallback_key_types", + ) + async def get_e2e_cross_signing_key( self, user_id: str, key_type: str, from_user_id: Optional[str] = None ) -> Optional[dict]: @@ -701,15 +770,37 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): " WHERE user_id = ? AND device_id = ? AND algorithm = ?" " LIMIT 1" ) + fallback_sql = ( + "SELECT key_id, key_json, used FROM e2e_fallback_keys_json" + " WHERE user_id = ? AND device_id = ? AND algorithm = ?" + " LIMIT 1" + ) result = {} delete = [] + used_fallbacks = [] for user_id, device_id, algorithm in query_list: user_result = result.setdefault(user_id, {}) device_result = user_result.setdefault(device_id, {}) txn.execute(sql, (user_id, device_id, algorithm)) - for key_id, key_json in txn: + otk_row = txn.fetchone() + if otk_row is not None: + key_id, key_json = otk_row device_result[algorithm + ":" + key_id] = key_json delete.append((user_id, device_id, algorithm, key_id)) + else: + # no one-time key available, so see if there's a fallback + # key + txn.execute(fallback_sql, (user_id, device_id, algorithm)) + fallback_row = txn.fetchone() + if fallback_row is not None: + key_id, key_json, used = fallback_row + device_result[algorithm + ":" + key_id] = key_json + if not used: + used_fallbacks.append( + (user_id, device_id, algorithm, key_id) + ) + + # drop any one-time keys that were claimed sql = ( "DELETE FROM e2e_one_time_keys_json" " WHERE user_id = ? AND device_id = ? AND algorithm = ?" @@ -726,6 +817,23 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): self._invalidate_cache_and_stream( txn, self.count_e2e_one_time_keys, (user_id, device_id) ) + # mark fallback keys as used + for user_id, device_id, algorithm, key_id in used_fallbacks: + self.db_pool.simple_update_txn( + txn, + "e2e_fallback_keys_json", + { + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + "key_id": key_id, + }, + {"used": True}, + ) + self._invalidate_cache_and_stream( + txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id) + ) + return result return await self.db_pool.runInteraction( @@ -754,6 +862,19 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): self._invalidate_cache_and_stream( txn, self.count_e2e_one_time_keys, (user_id, device_id) ) + self.db_pool.simple_delete_txn( + txn, + table="dehydrated_devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self.db_pool.simple_delete_txn( + txn, + table="e2e_fallback_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self._invalidate_cache_and_stream( + txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id) + ) await self.db_pool.runInteraction( "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 6d3689c09e..ebffd89251 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py
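The event_federation.py changes below add an _event_auth_cache mapping each event ID to its (auth_event_id, depth) pairs, so auth-chain walks only query the database for cache misses. An illustrative sketch of that miss-batching pattern (fetch_auth_from_db is a stand-in, not a Synapse function):

from typing import Dict, List, Tuple

_event_auth_cache: Dict[str, List[Tuple[str, int]]] = {}

def fetch_auth_from_db(event_ids):
    # Stand-in for the SQL join against event_auth/events in the real store.
    fake_db = {"$C": [("$B", 2)], "$B": [("$A", 1)], "$A": []}
    return {e: fake_db.get(e, []) for e in event_ids}

def get_auth_events(chunk):
    results, to_fetch = {}, []
    for event_id in chunk:
        cached = _event_auth_cache.get(event_id)
        if cached is None:
            to_fetch.append(event_id)      # miss: fetch from the database later
        else:
            results[event_id] = cached     # hit: no database work needed
    if to_fetch:
        fetched = fetch_auth_from_db(to_fetch)
        _event_auth_cache.update(fetched)  # populate the cache for next time
        results.update(fetched)
    return results

print(get_auth_events(["$C", "$B"]))  # first call fills the cache
print(get_auth_events(["$C"]))        # second call is served from the cache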
@@ -19,19 +19,33 @@ from typing import Dict, Iterable, List, Set, Tuple from synapse.api.errors import StoreError from synapse.events import EventBase -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore from synapse.types import Collection from synapse.util.caches.descriptors import cached +from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter logger = logging.getLogger(__name__) class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super().__init__(database, db_conn, hs) + + if hs.config.run_background_tasks: + hs.get_clock().looping_call( + self._delete_old_forward_extrem_cache, 60 * 60 * 1000 + ) + + # Cache of event ID to list of auth event IDs and their depths. + self._event_auth_cache = LruCache( + 500000, "_event_auth_cache", size_callback=len + ) # type: LruCache[str, List[Tuple[str, int]]] + async def get_auth_chain( self, event_ids: Collection[str], include_given: bool = False ) -> List[EventBase]: @@ -76,17 +90,45 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas else: results = set() - base_sql = "SELECT DISTINCT auth_id FROM event_auth WHERE " + # We pull out the depth simply so that we can populate the + # `_event_auth_cache` cache. + base_sql = """ + SELECT a.event_id, auth_id, depth + FROM event_auth AS a + INNER JOIN events AS e ON (e.event_id = a.auth_id) + WHERE + """ front = set(event_ids) while front: new_front = set() for chunk in batch_iter(front, 100): - clause, args = make_in_list_sql_clause( - txn.database_engine, "event_id", chunk - ) - txn.execute(base_sql + clause, args) - new_front.update(r[0] for r in txn) + # Pull the auth events either from the cache or DB. + to_fetch = [] # Event IDs to fetch from DB # type: List[str] + for event_id in chunk: + res = self._event_auth_cache.get(event_id) + if res is None: + to_fetch.append(event_id) + else: + new_front.update(auth_id for auth_id, depth in res) + + if to_fetch: + clause, args = make_in_list_sql_clause( + txn.database_engine, "a.event_id", to_fetch + ) + txn.execute(base_sql + clause, args) + + # Note we need to batch up the results by event ID before + # adding to the cache. + to_cache = {} + for event_id, auth_event_id, auth_event_depth in txn: + to_cache.setdefault(event_id, []).append( + (auth_event_id, auth_event_depth) + ) + new_front.add(auth_event_id) + + for event_id, auth_events in to_cache.items(): + self._event_auth_cache.set(event_id, auth_events) new_front -= results @@ -95,7 +137,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return list(results) - async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]: + async def get_auth_chain_difference( + self, room_id: str, state_sets: List[Set[str]] + ) -> Set[str]: """Given sets of state events figure out the auth chain difference (as per state res v2 algorithm). 
@@ -205,14 +249,38 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas break # Fetch the auth events and their depths of the N last events we're - # currently walking + # currently walking, either from cache or DB. search, chunk = search[:-100], search[-100:] - clause, args = make_in_list_sql_clause( - txn.database_engine, "a.event_id", [e_id for _, e_id in chunk] - ) - txn.execute(base_sql + clause, args) - for event_id, auth_event_id, auth_event_depth in txn: + found = [] # Results found # type: List[Tuple[str, str, int]] + to_fetch = [] # Event IDs to fetch from DB # type: List[str] + for _, event_id in chunk: + res = self._event_auth_cache.get(event_id) + if res is None: + to_fetch.append(event_id) + else: + found.extend((event_id, auth_id, depth) for auth_id, depth in res) + + if to_fetch: + clause, args = make_in_list_sql_clause( + txn.database_engine, "a.event_id", to_fetch + ) + txn.execute(base_sql + clause, args) + + # We parse the results and add the to the `found` set and the + # cache (note we need to batch up the results by event ID before + # adding to the cache). + to_cache = {} + for event_id, auth_event_id, auth_event_depth in txn: + to_cache.setdefault(event_id, []).append( + (auth_event_id, auth_event_depth) + ) + found.append((event_id, auth_event_id, auth_event_depth)) + + for event_id, auth_events in to_cache.items(): + self._event_auth_cache.set(event_id, auth_events) + + for event_id, auth_event_id, auth_event_depth in found: event_to_auth_events.setdefault(event_id, set()).add(auth_event_id) sets = event_to_missing_sets.get(auth_event_id) @@ -586,6 +654,28 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return [row["event_id"] for row in rows] + @wrap_as_background_process("delete_old_forward_extrem_cache") + async def _delete_old_forward_extrem_cache(self) -> None: + def _delete_old_forward_extrem_cache_txn(txn): + # Delete entries older than a month, while making sure we don't delete + # the only entries for a room. + sql = """ + DELETE FROM stream_ordering_to_exterm + WHERE + room_id IN ( + SELECT room_id + FROM stream_ordering_to_exterm + WHERE stream_ordering > ? + ) AND stream_ordering < ? + """ + txn.execute( + sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago) + ) + + await self.db_pool.runInteraction( + "_delete_old_forward_extrem_cache", _delete_old_forward_extrem_cache_txn, + ) + class EventFederationStore(EventFederationWorkerStore): """ Responsible for storing and serving up the various graphs associated @@ -606,34 +696,6 @@ class EventFederationStore(EventFederationWorkerStore): self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth ) - hs.get_clock().looping_call( - self._delete_old_forward_extrem_cache, 60 * 60 * 1000 - ) - - def _delete_old_forward_extrem_cache(self): - def _delete_old_forward_extrem_cache_txn(txn): - # Delete entries older than a month, while making sure we don't delete - # the only entries for a room. - sql = """ - DELETE FROM stream_ordering_to_exterm - WHERE - room_id IN ( - SELECT room_id - FROM stream_ordering_to_exterm - WHERE stream_ordering > ? - ) AND stream_ordering < ? 
- """ - txn.execute( - sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago) - ) - - return run_as_background_process( - "delete_old_forward_extrem_cache", - self.db_pool.runInteraction, - "_delete_old_forward_extrem_cache", - _delete_old_forward_extrem_cache_txn, - ) - async def clean_room_for_join(self, room_id): return await self.db_pool.runInteraction( "clean_room_for_join", self._clean_room_for_join_txn, room_id diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 62f1738732..e5c03cc609 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py
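The event_push_actions.py changes below move notification rotation onto the background-tasks worker and keep the _doing_notif_rotation flag so a slow rotation is not re-entered by the next looping call. A small asyncio sketch of that non-reentrant periodic job (rotate_once and the sleep are stand-ins for the real batch of database work):

import asyncio

class RotationSketch:
    def __init__(self):
        self._doing_rotation = False

    async def rotate_once(self):
        await asyncio.sleep(0.01)  # stand-in for rotating a batch of push actions

    async def _rotate_notifs(self):
        if self._doing_rotation:
            return  # a previous run is still in progress; skip this tick
        self._doing_rotation = True
        try:
            await self.rotate_once()
        finally:
            self._doing_rotation = False

asyncio.run(RotationSketch()._rotate_notifs())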
@@ -13,15 +13,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from typing import Dict, List, Optional, Tuple, Union import attr -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool +from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -74,19 +73,21 @@ class EventPushActionsWorkerStore(SQLBaseStore): self.stream_ordering_month_ago = None self.stream_ordering_day_ago = None - cur = LoggingTransaction( - db_conn.cursor(), - name="_find_stream_orderings_for_times_txn", - database_engine=self.database_engine, - ) + cur = db_conn.cursor(txn_name="_find_stream_orderings_for_times_txn") self._find_stream_orderings_for_times_txn(cur) cur.close() self.find_stream_orderings_looping_call = self._clock.looping_call( self._find_stream_orderings_for_times, 10 * 60 * 1000 ) + self._rotate_delay = 3 self._rotate_count = 10000 + self._doing_notif_rotation = False + if hs.config.run_background_tasks: + self._rotate_notif_loop = self._clock.looping_call( + self._rotate_notifs, 30 * 60 * 1000 + ) @cached(num_args=3, tree=True, max_entries=5000) async def get_unread_event_push_actions_by_room_for_user( @@ -518,15 +519,14 @@ class EventPushActionsWorkerStore(SQLBaseStore): "Error removing push actions after event persistence failure" ) - def _find_stream_orderings_for_times(self): - return run_as_background_process( - "event_push_action_stream_orderings", - self.db_pool.runInteraction, + @wrap_as_background_process("event_push_action_stream_orderings") + async def _find_stream_orderings_for_times(self) -> None: + await self.db_pool.runInteraction( "_find_stream_orderings_for_times", self._find_stream_orderings_for_times_txn, ) - def _find_stream_orderings_for_times_txn(self, txn): + def _find_stream_orderings_for_times_txn(self, txn: LoggingTransaction) -> None: logger.info("Searching for stream ordering 1 month ago") self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn( txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000 @@ -656,129 +656,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) return result[0] if result else None - -class EventPushActionsStore(EventPushActionsWorkerStore): - EPA_HIGHLIGHT_INDEX = "epa_highlight_index" - - def __init__(self, database: DatabasePool, db_conn, hs): - super().__init__(database, db_conn, hs) - - self.db_pool.updates.register_background_index_update( - self.EPA_HIGHLIGHT_INDEX, - index_name="event_push_actions_u_highlight", - table="event_push_actions", - columns=["user_id", "stream_ordering"], - ) - - self.db_pool.updates.register_background_index_update( - "event_push_actions_highlights_index", - index_name="event_push_actions_highlights_index", - table="event_push_actions", - columns=["user_id", "room_id", "topological_ordering", "stream_ordering"], - where_clause="highlight=1", - ) - - self._doing_notif_rotation = False - self._rotate_notif_loop = self._clock.looping_call( - self._start_rotate_notifs, 30 * 60 * 1000 - ) - - async def get_push_actions_for_user( - self, user_id, 
before=None, limit=50, only_highlight=False - ): - def f(txn): - before_clause = "" - if before: - before_clause = "AND epa.stream_ordering < ?" - args = [user_id, before, limit] - else: - args = [user_id, limit] - - if only_highlight: - if len(before_clause) > 0: - before_clause += " " - before_clause += "AND epa.highlight = 1" - - # NB. This assumes event_ids are globally unique since - # it makes the query easier to index - sql = ( - "SELECT epa.event_id, epa.room_id," - " epa.stream_ordering, epa.topological_ordering," - " epa.actions, epa.highlight, epa.profile_tag, e.received_ts" - " FROM event_push_actions epa, events e" - " WHERE epa.event_id = e.event_id" - " AND epa.user_id = ? %s" - " AND epa.notif = 1" - " ORDER BY epa.stream_ordering DESC" - " LIMIT ?" % (before_clause,) - ) - txn.execute(sql, args) - return self.db_pool.cursor_to_dict(txn) - - push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f) - for pa in push_actions: - pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) - return push_actions - - async def get_latest_push_action_stream_ordering(self): - def f(txn): - txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") - return txn.fetchone() - - result = await self.db_pool.runInteraction( - "get_latest_push_action_stream_ordering", f - ) - return result[0] or 0 - - def _remove_old_push_actions_before_txn( - self, txn, room_id, user_id, stream_ordering - ): - """ - Purges old push actions for a user and room before a given - stream_ordering. - - We however keep a months worth of highlighted notifications, so that - users can still get a list of recent highlights. - - Args: - txn: The transcation - room_id: Room ID to delete from - user_id: user ID to delete for - stream_ordering: The lowest stream ordering which will - not be deleted. - """ - txn.call_after( - self.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (room_id, user_id), - ) - - # We need to join on the events table to get the received_ts for - # event_push_actions and sqlite won't let us use a join in a delete so - # we can't just delete where received_ts < x. Furthermore we can - # only identify event_push_actions by a tuple of room_id, event_id - # we we can't use a subquery. - # Instead, we look up the stream ordering for the last event in that - # room received before the threshold time and delete event_push_actions - # in the room with a stream_odering before that. - txn.execute( - "DELETE FROM event_push_actions " - " WHERE user_id = ? AND room_id = ? AND " - " stream_ordering <= ?" - " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)", - (user_id, room_id, stream_ordering, self.stream_ordering_month_ago), - ) - - txn.execute( - """ - DELETE FROM event_push_summary - WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? 
- """, - (room_id, user_id, stream_ordering), - ) - - def _start_rotate_notifs(self): - return run_as_background_process("rotate_notifs", self._rotate_notifs) - + @wrap_as_background_process("rotate_notifs") async def _rotate_notifs(self): if self._doing_notif_rotation or self.stream_ordering_day_ago is None: return @@ -958,6 +836,111 @@ class EventPushActionsStore(EventPushActionsWorkerStore): ) +class EventPushActionsStore(EventPushActionsWorkerStore): + EPA_HIGHLIGHT_INDEX = "epa_highlight_index" + + def __init__(self, database: DatabasePool, db_conn, hs): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_index_update( + self.EPA_HIGHLIGHT_INDEX, + index_name="event_push_actions_u_highlight", + table="event_push_actions", + columns=["user_id", "stream_ordering"], + ) + + self.db_pool.updates.register_background_index_update( + "event_push_actions_highlights_index", + index_name="event_push_actions_highlights_index", + table="event_push_actions", + columns=["user_id", "room_id", "topological_ordering", "stream_ordering"], + where_clause="highlight=1", + ) + + async def get_push_actions_for_user( + self, user_id, before=None, limit=50, only_highlight=False + ): + def f(txn): + before_clause = "" + if before: + before_clause = "AND epa.stream_ordering < ?" + args = [user_id, before, limit] + else: + args = [user_id, limit] + + if only_highlight: + if len(before_clause) > 0: + before_clause += " " + before_clause += "AND epa.highlight = 1" + + # NB. This assumes event_ids are globally unique since + # it makes the query easier to index + sql = ( + "SELECT epa.event_id, epa.room_id," + " epa.stream_ordering, epa.topological_ordering," + " epa.actions, epa.highlight, epa.profile_tag, e.received_ts" + " FROM event_push_actions epa, events e" + " WHERE epa.event_id = e.event_id" + " AND epa.user_id = ? %s" + " AND epa.notif = 1" + " ORDER BY epa.stream_ordering DESC" + " LIMIT ?" % (before_clause,) + ) + txn.execute(sql, args) + return self.db_pool.cursor_to_dict(txn) + + push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f) + for pa in push_actions: + pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) + return push_actions + + def _remove_old_push_actions_before_txn( + self, txn, room_id, user_id, stream_ordering + ): + """ + Purges old push actions for a user and room before a given + stream_ordering. + + We however keep a months worth of highlighted notifications, so that + users can still get a list of recent highlights. + + Args: + txn: The transcation + room_id: Room ID to delete from + user_id: user ID to delete for + stream_ordering: The lowest stream ordering which will + not be deleted. + """ + txn.call_after( + self.get_unread_event_push_actions_by_room_for_user.invalidate_many, + (room_id, user_id), + ) + + # We need to join on the events table to get the received_ts for + # event_push_actions and sqlite won't let us use a join in a delete so + # we can't just delete where received_ts < x. Furthermore we can + # only identify event_push_actions by a tuple of room_id, event_id + # we we can't use a subquery. + # Instead, we look up the stream ordering for the last event in that + # room received before the threshold time and delete event_push_actions + # in the room with a stream_odering before that. + txn.execute( + "DELETE FROM event_push_actions " + " WHERE user_id = ? AND room_id = ? AND " + " stream_ordering <= ?" + " AND ((stream_ordering < ? 
AND highlight = 1) or highlight = 0)", + (user_id, room_id, stream_ordering, self.stream_ordering_month_ago), + ) + + txn.execute( + """ + DELETE FROM event_push_summary + WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? + """, + (room_id, user_id, stream_ordering), + ) + + def _action_has_highlight(actions): for action in actions: try: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 18def01f50..90fb1a1f00 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py
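The events.py changes below start persisting a mapping from client transaction IDs to event IDs (_persist_transaction_ids_txn), which allows retried client sends to be deduplicated. A sketch of how those rows are assembled from event metadata (the FakeEvent class and values are invented for illustration):

import time
from dataclasses import dataclass, field

@dataclass
class FakeEvent:
    event_id: str
    room_id: str
    sender: str
    internal_metadata: dict = field(default_factory=dict)

def build_txn_id_rows(events):
    to_insert = []
    for event in events:
        token_id = event.internal_metadata.get("token_id")
        txn_id = event.internal_metadata.get("txn_id")
        if token_id and txn_id:
            # Only events sent with a client transaction ID get a mapping row.
            to_insert.append(
                {
                    "event_id": event.event_id,
                    "room_id": event.room_id,
                    "user_id": event.sender,
                    "token_id": token_id,
                    "txn_id": txn_id,
                    "inserted_ts": int(time.time() * 1000),
                }
            )
    return to_insert

ev = FakeEvent("$abc", "!room:example.org", "@alice:example.org",
               {"token_id": 1, "txn_id": "m123"})
print(build_txn_id_rows([ev, FakeEvent("$def", "!room:example.org", "@bob:example.org")]))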
@@ -34,7 +34,7 @@ from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.search import SearchEntry from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import StateMap, get_domain_from_id -from synapse.util.frozenutils import frozendict_json_encoder +from synapse.util import json_encoder from synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -52,16 +52,6 @@ event_counter = Counter( ) -def encode_json(json_object): - """ - Encode a Python object as JSON and return it in a Unicode string. - """ - out = frozendict_json_encoder.encode(json_object) - if isinstance(out, bytes): - out = out.decode("utf8") - return out - - _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) @@ -341,6 +331,10 @@ class PersistEventsStore: min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering + # stream orderings should have been assigned by now + assert min_stream_order + assert max_stream_order + self._update_forward_extremities_txn( txn, new_forward_extremities=new_forward_extremeties, @@ -367,6 +361,8 @@ class PersistEventsStore: self._store_event_txn(txn, events_and_contexts=events_and_contexts) + self._persist_transaction_ids_txn(txn, events_and_contexts) + # Insert into event_to_state_groups. self._store_event_state_mappings_txn(txn, events_and_contexts) @@ -411,6 +407,35 @@ class PersistEventsStore: # room_memberships, where applicable. self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) + def _persist_transaction_ids_txn( + self, + txn: LoggingTransaction, + events_and_contexts: List[Tuple[EventBase, EventContext]], + ): + """Persist the mapping from transaction IDs to event IDs (if defined). + """ + + to_insert = [] + for event, _ in events_and_contexts: + token_id = getattr(event.internal_metadata, "token_id", None) + txn_id = getattr(event.internal_metadata, "txn_id", None) + if token_id and txn_id: + to_insert.append( + { + "event_id": event.event_id, + "room_id": event.room_id, + "user_id": event.sender, + "token_id": token_id, + "txn_id": txn_id, + "inserted_ts": self._clock.time_msec(), + } + ) + + if to_insert: + self.db_pool.simple_insert_many_txn( + txn, table="event_txn_id", values=to_insert, + ) + def _update_current_state_txn( self, txn: LoggingTransaction, @@ -432,12 +457,12 @@ class PersistEventsStore: # so that async background tasks get told what happened. sql = """ INSERT INTO current_state_delta_stream - (stream_id, room_id, type, state_key, event_id, prev_event_id) - SELECT ?, room_id, type, state_key, null, event_id + (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id) + SELECT ?, ?, room_id, type, state_key, null, event_id FROM current_state_events WHERE room_id = ? """ - txn.execute(sql, (stream_id, room_id)) + txn.execute(sql, (stream_id, self._instance_name, room_id)) self.db_pool.simple_delete_txn( txn, table="current_state_events", keyvalues={"room_id": room_id}, @@ -458,8 +483,8 @@ class PersistEventsStore: # sql = """ INSERT INTO current_state_delta_stream - (stream_id, room_id, type, state_key, event_id, prev_event_id) - SELECT ?, ?, ?, ?, ?, ( + (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id) + SELECT ?, ?, ?, ?, ?, ?, ( SELECT event_id FROM current_state_events WHERE room_id = ? AND type = ? AND state_key = ? 
) @@ -469,6 +494,7 @@ class PersistEventsStore: ( ( stream_id, + self._instance_name, room_id, etype, state_key, @@ -743,7 +769,7 @@ class PersistEventsStore: logger.exception("") raise - metadata_json = encode_json(event.internal_metadata.get_dict()) + metadata_json = json_encoder.encode(event.internal_metadata.get_dict()) sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?" txn.execute(sql, (metadata_json, event.event_id)) @@ -759,6 +785,7 @@ class PersistEventsStore: "event_stream_ordering": stream_order, "event_id": event.event_id, "state_group": state_group_id, + "instance_name": self._instance_name, }, ) @@ -797,10 +824,10 @@ class PersistEventsStore: { "event_id": event.event_id, "room_id": event.room_id, - "internal_metadata": encode_json( + "internal_metadata": json_encoder.encode( event.internal_metadata.get_dict() ), - "json": encode_json(event_dict(event)), + "json": json_encoder.encode(event_dict(event)), "format_version": event.format_version, } for event, _ in events_and_contexts @@ -1022,9 +1049,7 @@ class PersistEventsStore: def prefill(): for cache_entry in to_prefill: - self.store._get_event_cache.prefill( - (cache_entry[0].event_id,), cache_entry - ) + self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry) txn.call_after(prefill) @@ -1241,6 +1266,10 @@ class PersistEventsStore: ) def _store_retention_policy_for_room_txn(self, txn, event): + if not event.is_state(): + logger.debug("Ignoring non-state m.room.retention event") + return + if hasattr(event, "content") and ( "min_lifetime" in event.content or "max_lifetime" in event.content ): diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 5e4af2eb51..97b6754846 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -92,6 +92,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): where_clause="NOT have_censored", ) + self.db_pool.updates.register_background_index_update( + "users_have_local_media", + index_name="users_have_local_media", + table="local_media_repository", + columns=["user_id", "created_ts"], + ) + async def _background_reindex_fields_sender(self, progress, batch_size): target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index f95679ebc4..4732685f6e 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py
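The events_worker.py changes below add get_stripped_room_state_from_event_context, which reduces selected state events to just their type, state_key, content and sender, the shape sent in invites. A standalone sketch of that stripping step (the event dicts are invented for illustration):

def strip_state_events(events, state_types_to_include, membership_user_id=None):
    stripped = []
    for ev in events:
        wanted = ev["type"] in state_types_to_include or (
            membership_user_id is not None
            and (ev["type"], ev["state_key"]) == ("m.room.member", membership_user_id)
        )
        if wanted:
            # Keep only the four stripped-state keys; everything else is dropped.
            stripped.append(
                {k: ev[k] for k in ("type", "state_key", "content", "sender")}
            )
    return stripped

events = [
    {"type": "m.room.name", "state_key": "", "content": {"name": "Lounge"},
     "sender": "@alice:example.org", "origin_server_ts": 0},
    {"type": "m.room.member", "state_key": "@alice:example.org",
     "content": {"membership": "join"}, "sender": "@alice:example.org"},
]
print(strip_state_events(events, {"m.room.name"}, membership_user_id="@alice:example.org"))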
@@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import itertools import logging import threading @@ -32,9 +31,13 @@ from synapse.api.room_versions import ( RoomVersions, ) from synapse.events import EventBase, make_event_from_dict +from synapse.events.snapshot import EventContext from synapse.events.utils import prune_event from synapse.logging.context import PreserveLoggingContext, current_context -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream @@ -42,8 +45,9 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator -from synapse.types import Collection, get_domain_from_id -from synapse.util.caches.descriptors import Cache, cached +from synapse.types import Collection, JsonDict, get_domain_from_id +from synapse.util.caches.descriptors import cached +from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -74,6 +78,13 @@ class EventRedactBehaviour(Names): class EventsWorkerStore(SQLBaseStore): + # Whether to use dedicated DB threads for event fetching. This is only used + # if there are multiple DB threads available. When used will lock the DB + # thread for periods of time (so unit tests want to disable this when they + # run DB transactions on the main thread). See EVENT_QUEUE_* for more + # options controlling this. + USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True + def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) @@ -130,11 +141,16 @@ class EventsWorkerStore(SQLBaseStore): db_conn, "events", "stream_ordering", step=-1 ) - self._get_event_cache = Cache( - "*getEvent*", + if hs.config.run_background_tasks: + # We periodically clean out old transaction ID mappings + self._clock.looping_call( + self._cleanup_old_transaction_ids, 5 * 60 * 1000, + ) + + self._get_event_cache = LruCache( + cache_name="*getEvent*", keylen=3, - max_entries=hs.config.caches.event_cache_size, - apply_cache_factor_from_config=False, + max_size=hs.config.caches.event_cache_size, ) self._event_fetch_lock = threading.Condition() @@ -510,6 +526,57 @@ class EventsWorkerStore(SQLBaseStore): return event_map + async def get_stripped_room_state_from_event_context( + self, + context: EventContext, + state_types_to_include: List[EventTypes], + membership_user_id: Optional[str] = None, + ) -> List[JsonDict]: + """ + Retrieve the stripped state from a room, given an event context to retrieve state + from as well as the state types to include. Optionally, include the membership + events from a specific user. + + "Stripped" state means that only the `type`, `state_key`, `content` and `sender` keys + are included from each state event. + + Args: + context: The event context to retrieve state of the room from. + state_types_to_include: The type of state events to include. 
+ membership_user_id: An optional user ID to include the stripped membership state + events of. This is useful when generating the stripped state of a room for + invites. We want to send membership events of the inviter, so that the + invitee can display the inviter's profile information if the room lacks any. + + Returns: + A list of dictionaries, each representing a stripped state event from the room. + """ + current_state_ids = await context.get_current_state_ids() + + # We know this event is not an outlier, so this must be + # non-None. + assert current_state_ids is not None + + # The state to include + state_to_include_ids = [ + e_id + for k, e_id in current_state_ids.items() + if k[0] in state_types_to_include + or (membership_user_id and k == (EventTypes.Member, membership_user_id)) + ] + + state_to_include = await self.get_events(state_to_include_ids) + + return [ + { + "type": e.type, + "state_key": e.state_key, + "content": e.content, + "sender": e.sender, + } + for e in state_to_include.values() + ] + def _do_fetch(self, conn): """Takes a database connection and waits for requests for events from the _event_fetch_list queue. @@ -522,7 +589,11 @@ class EventsWorkerStore(SQLBaseStore): if not event_list: single_threaded = self.database_engine.single_threaded - if single_threaded or i > EVENT_QUEUE_ITERATIONS: + if ( + not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING + or single_threaded + or i > EVENT_QUEUE_ITERATIONS + ): self._event_fetch_ongoing -= 1 return else: @@ -712,6 +783,7 @@ class EventsWorkerStore(SQLBaseStore): internal_metadata_dict=internal_metadata, rejected_reason=rejected_reason, ) + original_ev.internal_metadata.stream_ordering = row["stream_ordering"] event_map[event_id] = original_ev @@ -728,7 +800,7 @@ class EventsWorkerStore(SQLBaseStore): event=original_ev, redacted_event=redacted_event ) - self._get_event_cache.prefill((event_id,), cache_entry) + self._get_event_cache.set((event_id,), cache_entry) result_map[event_id] = cache_entry return result_map @@ -779,6 +851,8 @@ class EventsWorkerStore(SQLBaseStore): * event_id (str) + * stream_ordering (int): stream ordering for this event + * json (str): json-encoded event structure * internal_metadata (str): json-encoded internal metadata dict @@ -811,13 +885,15 @@ class EventsWorkerStore(SQLBaseStore): sql = """\ SELECT e.event_id, - e.internal_metadata, - e.json, - e.format_version, + e.stream_ordering, + ej.internal_metadata, + ej.json, + ej.format_version, r.room_version, rej.reason - FROM event_json as e - LEFT JOIN rooms r USING (room_id) + FROM events AS e + JOIN event_json AS ej USING (event_id) + LEFT JOIN rooms r ON r.room_id = e.room_id LEFT JOIN rejections as rej USING (event_id) WHERE """ @@ -831,11 +907,12 @@ class EventsWorkerStore(SQLBaseStore): event_id = row[0] event_dict[event_id] = { "event_id": event_id, - "internal_metadata": row[1], - "json": row[2], - "format_version": row[3], - "room_version_id": row[4], - "rejected_reason": row[5], + "stream_ordering": row[1], + "internal_metadata": row[2], + "json": row[3], + "format_version": row[4], + "room_version_id": row[5], + "rejected_reason": row[6], "redactions": [], } @@ -1017,16 +1094,12 @@ class EventsWorkerStore(SQLBaseStore): return {"v1": complexity_v1} - def get_current_backfill_token(self): - """The current minimum token that backfilled events have reached""" - return -self._backfill_id_gen.get_current_token() - def get_current_events_token(self): """The current maximum token that events have reached""" return 
self._stream_id_gen.get_current_token() async def get_all_new_forward_event_rows( - self, last_id: int, current_id: int, limit: int + self, instance_name: str, last_id: int, current_id: int, limit: int ) -> List[Tuple]: """Returns new events, for the Events replication stream @@ -1044,16 +1117,19 @@ class EventsWorkerStore(SQLBaseStore): def get_all_new_forward_event_rows(txn): sql = ( "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" + " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" " FROM events AS e" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" + " LEFT JOIN room_memberships USING (event_id)" + " LEFT JOIN rejections USING (event_id)" " WHERE ? < stream_ordering AND stream_ordering <= ?" + " AND instance_name = ?" " ORDER BY stream_ordering ASC" " LIMIT ?" ) - txn.execute(sql, (last_id, current_id, limit)) + txn.execute(sql, (last_id, current_id, instance_name, limit)) return txn.fetchall() return await self.db_pool.runInteraction( @@ -1061,7 +1137,7 @@ class EventsWorkerStore(SQLBaseStore): ) async def get_ex_outlier_stream_rows( - self, last_id: int, current_id: int + self, instance_name: str, last_id: int, current_id: int ) -> List[Tuple]: """Returns de-outliered events, for the Events replication stream @@ -1078,18 +1154,21 @@ class EventsWorkerStore(SQLBaseStore): def get_ex_outlier_stream_rows_txn(txn): sql = ( "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" + " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" " FROM events AS e" - " INNER JOIN ex_outlier_stream USING (event_id)" + " INNER JOIN ex_outlier_stream AS out USING (event_id)" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" + " LEFT JOIN room_memberships USING (event_id)" + " LEFT JOIN rejections USING (event_id)" " WHERE ? < event_stream_ordering" " AND event_stream_ordering <= ?" + " AND out.instance_name = ?" " ORDER BY event_stream_ordering ASC" ) - txn.execute(sql, (last_id, current_id)) + txn.execute(sql, (last_id, current_id, instance_name)) return txn.fetchall() return await self.db_pool.runInteraction( @@ -1102,6 +1181,9 @@ class EventsWorkerStore(SQLBaseStore): """Get updates for backfill replication stream, including all new backfilled events and events that have gone from being outliers to not. + NOTE: The IDs given here are from replication, and so should be + *positive*. + Args: instance_name: The writer we want to fetch updates from. Unused here since there is only ever one writer. @@ -1132,10 +1214,11 @@ class EventsWorkerStore(SQLBaseStore): " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" " WHERE ? > stream_ordering AND stream_ordering >= ?" + " AND instance_name = ?" " ORDER BY stream_ordering ASC" " LIMIT ?" 
) - txn.execute(sql, (-last_id, -current_id, limit)) + txn.execute(sql, (-last_id, -current_id, instance_name, limit)) new_event_updates = [(row[0], row[1:]) for row in txn] limited = False @@ -1149,15 +1232,16 @@ class EventsWorkerStore(SQLBaseStore): "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," " state_key, redacts, relates_to_id" " FROM events AS e" - " INNER JOIN ex_outlier_stream USING (event_id)" + " INNER JOIN ex_outlier_stream AS out USING (event_id)" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" " WHERE ? > event_stream_ordering" " AND event_stream_ordering >= ?" + " AND out.instance_name = ?" " ORDER BY event_stream_ordering DESC" ) - txn.execute(sql, (-last_id, -upper_bound)) + txn.execute(sql, (-last_id, -upper_bound, instance_name)) new_event_updates.extend((row[0], row[1:]) for row in txn) if len(new_event_updates) >= limit: @@ -1171,7 +1255,7 @@ class EventsWorkerStore(SQLBaseStore): ) async def get_all_updated_current_state_deltas( - self, from_token: int, to_token: int, target_row_count: int + self, instance_name: str, from_token: int, to_token: int, target_row_count: int ) -> Tuple[List[Tuple], int, bool]: """Fetch updates from current_state_delta_stream @@ -1197,9 +1281,10 @@ class EventsWorkerStore(SQLBaseStore): SELECT stream_id, room_id, type, state_key, event_id FROM current_state_delta_stream WHERE ? < stream_id AND stream_id <= ? + AND instance_name = ? ORDER BY stream_id ASC LIMIT ? """ - txn.execute(sql, (from_token, to_token, target_row_count)) + txn.execute(sql, (from_token, to_token, instance_name, target_row_count)) return txn.fetchall() def get_deltas_for_stream_id_txn(txn, stream_id): @@ -1287,3 +1372,77 @@ class EventsWorkerStore(SQLBaseStore): return await self.db_pool.runInteraction( desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) + + async def get_event_id_from_transaction_id( + self, room_id: str, user_id: str, token_id: int, txn_id: str + ) -> Optional[str]: + """Look up if we have already persisted an event for the transaction ID, + returning the event ID if so. + """ + return await self.db_pool.simple_select_one_onecol( + table="event_txn_id", + keyvalues={ + "room_id": room_id, + "user_id": user_id, + "token_id": token_id, + "txn_id": txn_id, + }, + retcol="event_id", + allow_none=True, + desc="get_event_id_from_transaction_id", + ) + + async def get_already_persisted_events( + self, events: Iterable[EventBase] + ) -> Dict[str, str]: + """Look up if we have already persisted an event for the transaction ID, + returning a mapping from event ID in the given list to the event ID of + an existing event. + + Also checks if there are duplicates in the given events, if there are + will map duplicates to the *first* event. + """ + + mapping = {} + txn_id_to_event = {} # type: Dict[Tuple[str, int, str], str] + + for event in events: + token_id = getattr(event.internal_metadata, "token_id", None) + txn_id = getattr(event.internal_metadata, "txn_id", None) + + if token_id and txn_id: + # Check if this is a duplicate of an event in the given events. + existing = txn_id_to_event.get((event.room_id, token_id, txn_id)) + if existing: + mapping[event.event_id] = existing + continue + + # Check if this is a duplicate of an event we've already + # persisted. 
+ existing = await self.get_event_id_from_transaction_id( + event.room_id, event.sender, token_id, txn_id + ) + if existing: + mapping[event.event_id] = existing + txn_id_to_event[(event.room_id, token_id, txn_id)] = existing + else: + txn_id_to_event[(event.room_id, token_id, txn_id)] = event.event_id + + return mapping + + @wrap_as_background_process("_cleanup_old_transaction_ids") + async def _cleanup_old_transaction_ids(self): + """Cleans out transaction id mappings older than 24hrs. + """ + + def _cleanup_old_transaction_ids_txn(txn): + sql = """ + DELETE FROM event_txn_id + WHERE inserted_ts < ? + """ + one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000 + txn.execute(sql, (one_day_ago,)) + + return await self.db_pool.runInteraction( + "_cleanup_old_transaction_ids", _cleanup_old_transaction_ids_txn, + ) diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index ad43bb05ab..04ac2d0ced 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py
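The events_worker.py hunks above add transaction-ID based deduplication: get_already_persisted_events maps an event ID in a batch to the ID of an event already persisted for the same (room_id, token_id, txn_id), covering duplicates within the batch as well as rows in the new event_txn_id table. A minimal sketch of how a persister might consume that mapping; store, events and persist are illustrative assumptions, not part of this change.

    async def dedup_before_persist(store, events, persist):
        # Map event_id -> already-persisted event_id for duplicates of the
        # same (room_id, token_id, txn_id), as computed by the new helper.
        already_persisted = await store.get_already_persisted_events(events)

        # Persist only the genuinely new events; callers can hand back the
        # existing event for anything that appears in the mapping.
        new_events = [e for e in events if e.event_id not in already_persisted]
        await persist(new_events)
        return already_persisted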
@@ -22,6 +22,7 @@ from signedjson.key import decode_verify_key_bytes from synapse.storage._base import SQLBaseStore from synapse.storage.keys import FetchKeyResult +from synapse.storage.types import Cursor from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter @@ -44,7 +45,7 @@ class KeyStore(SQLBaseStore): ) async def get_server_verify_keys( self, server_name_and_key_ids: Iterable[Tuple[str, str]] - ) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]: + ) -> Dict[Tuple[str, str], FetchKeyResult]: """ Args: server_name_and_key_ids: @@ -56,7 +57,7 @@ class KeyStore(SQLBaseStore): """ keys = {} - def _get_keys(txn, batch): + def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str]]) -> None: """Processes a batch of keys to fetch, and adds the result to `keys`.""" # batch_iter always returns tuples so it's safe to do len(batch) @@ -77,13 +78,12 @@ class KeyStore(SQLBaseStore): # `ts_valid_until_ms`. ts_valid_until_ms = 0 - res = FetchKeyResult( + keys[(server_name, key_id)] = FetchKeyResult( verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)), valid_until_ts=ts_valid_until_ms, ) - keys[(server_name, key_id)] = res - def _txn(txn): + def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]: for batch in batch_iter(server_name_and_key_ids, 50): _get_keys(txn, batch) return keys @@ -122,9 +122,7 @@ class KeyStore(SQLBaseStore): # param, which is itself the 2-tuple (server_name, key_id). invalidations.append((server_name, key_id)) - await self.db_pool.runInteraction( - "store_server_verify_keys", - self.db_pool.simple_upsert_many_txn, + await self.db_pool.simple_upsert_many( table="server_signature_keys", key_names=("server_name", "key_id"), key_values=key_values, @@ -135,6 +133,7 @@ class KeyStore(SQLBaseStore): "verify_key", ), value_values=value_values, + desc="store_server_verify_keys", ) invalidate = self._get_server_verify_key.invalidate diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index cc538c5c10..4b2f224718 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py
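The keys.py hunk above tightens get_server_verify_keys to return Dict[Tuple[str, str], FetchKeyResult]: keys that are not known are simply absent from the result rather than mapped to None. A hedged caller sketch, assuming store is a KeyStore; the server and key names are invented.

    async def fetch_signing_keys(store):
        wanted = [("example.org", "ed25519:auto"), ("example.com", "ed25519:key1")]
        found = await store.get_server_verify_keys(wanted)

        for server_and_key_id in wanted:
            result = found.get(server_and_key_id)
            if result is None:
                continue  # not cached locally; absent from the dict
            # FetchKeyResult exposes verify_key and valid_until_ts.
            print(server_and_key_id, result.valid_until_ts)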
@@ -93,6 +93,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) + self.server_name = hs.hostname async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]: """Get the metadata for a local piece of media @@ -115,6 +116,109 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="get_local_media", ) + async def get_local_media_by_user_paginate( + self, start: int, limit: int, user_id: str + ) -> Tuple[List[Dict[str, Any]], int]: + """Get a paginated list of metadata for a local piece of media + which an user_id has uploaded + + Args: + start: offset in the list + limit: maximum amount of media_ids to retrieve + user_id: fully-qualified user id + Returns: + A paginated list of all metadata of user's media, + plus the total count of all the user's media + """ + + def get_local_media_by_user_paginate_txn(txn): + + args = [user_id] + sql = """ + SELECT COUNT(*) as total_media + FROM local_media_repository + WHERE user_id = ? + """ + txn.execute(sql, args) + count = txn.fetchone()[0] + + sql = """ + SELECT + "media_id", + "media_type", + "media_length", + "upload_name", + "created_ts", + "last_access_ts", + "quarantined_by", + "safe_from_quarantine" + FROM local_media_repository + WHERE user_id = ? + ORDER BY created_ts DESC, media_id DESC + LIMIT ? OFFSET ? + """ + + args += [limit, start] + txn.execute(sql, args) + media = self.db_pool.cursor_to_dict(txn) + return media, count + + return await self.db_pool.runInteraction( + "get_local_media_by_user_paginate_txn", get_local_media_by_user_paginate_txn + ) + + async def get_local_media_before( + self, before_ts: int, size_gt: int, keep_profiles: bool, + ) -> Optional[List[str]]: + + # to find files that have never been accessed (last_access_ts IS NULL) + # compare with `created_ts` + sql = """ + SELECT media_id + FROM local_media_repository AS lmr + WHERE + ( last_access_ts < ? + OR ( created_ts < ? AND last_access_ts IS NULL ) ) + AND media_length > ? + """ + + if keep_profiles: + sql_keep = """ + AND ( + NOT EXISTS + (SELECT 1 + FROM profiles + WHERE profiles.avatar_url = '{media_prefix}' || lmr.media_id) + AND NOT EXISTS + (SELECT 1 + FROM groups + WHERE groups.avatar_url = '{media_prefix}' || lmr.media_id) + AND NOT EXISTS + (SELECT 1 + FROM room_memberships + WHERE room_memberships.avatar_url = '{media_prefix}' || lmr.media_id) + AND NOT EXISTS + (SELECT 1 + FROM user_directory + WHERE user_directory.avatar_url = '{media_prefix}' || lmr.media_id) + AND NOT EXISTS + (SELECT 1 + FROM room_stats_state + WHERE room_stats_state.avatar = '{media_prefix}' || lmr.media_id) + ) + """.format( + media_prefix="mxc://%s/" % (self.server_name,), + ) + sql += sql_keep + + def _get_local_media_before_txn(txn): + txn.execute(sql, (before_ts, before_ts, size_gt)) + return [row[0] for row in txn] + + return await self.db_pool.runInteraction( + "get_local_media_before", _get_local_media_before_txn + ) + async def store_local_media( self, media_id, @@ -348,6 +452,33 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="get_remote_media_thumbnails", ) + async def get_remote_media_thumbnail( + self, origin: str, media_id: str, t_width: int, t_height: int, t_type: str, + ) -> Optional[Dict[str, Any]]: + """Fetch the thumbnail info of given width, height and type. 
+ """ + + return await self.db_pool.simple_select_one( + table="remote_media_cache_thumbnails", + keyvalues={ + "media_origin": origin, + "media_id": media_id, + "thumbnail_width": t_width, + "thumbnail_height": t_height, + "thumbnail_type": t_type, + }, + retcols=( + "thumbnail_width", + "thumbnail_height", + "thumbnail_method", + "thumbnail_type", + "thumbnail_length", + "filesystem_id", + ), + allow_none=True, + desc="get_remote_media_thumbnail", + ) + async def store_remote_media_thumbnail( self, origin, diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index 92099f95ce..ab18cc4d79 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py
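The media_repository.py hunks above add get_local_media_by_user_paginate, which returns the rows for one page plus the total count of the user's media. A hedged sketch of paging through a user's uploads; the function name and column choices below are illustrative only.

    async def iter_user_media(store, user_id, page_size=100):
        # Walk a user's uploads page by page; each call returns this page's
        # rows and the total number of media items for the user.
        start = 0
        while True:
            rows, total = await store.get_local_media_by_user_paginate(
                start, page_size, user_id
            )
            for row in rows:
                yield row["media_id"], row["media_length"]
            start += page_size
            if not rows or start >= total:
                break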
@@ -12,15 +12,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import calendar +import logging +import time +from typing import Dict from synapse.metrics import GaugeBucketCollector -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, ) +logger = logging.getLogger(__name__) + # Collect metrics on the number of forward extremities that exist. _extremities_collecter = GaugeBucketCollector( "synapse_forward_extremities", @@ -51,15 +57,13 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): super().__init__(database, db_conn, hs) # Read the extrems every 60 minutes - def read_forward_extremities(): - # run as a background process to make sure that the database transactions - # have a logcontext to report to - return run_as_background_process( - "read_forward_extremities", self._read_forward_extremities - ) + if hs.config.run_background_tasks: + self._clock.looping_call(self._read_forward_extremities, 60 * 60 * 1000) - hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000) + # Used in _generate_user_daily_visits to keep track of progress + self._last_user_visit_update = self._get_start_of_day() + @wrap_as_background_process("read_forward_extremities") async def _read_forward_extremities(self): def fetch(txn): txn.execute( @@ -137,3 +141,196 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): return count return await self.db_pool.runInteraction("count_daily_active_rooms", _count) + + async def count_daily_users(self) -> int: + """ + Counts the number of users who used this homeserver in the last 24 hours. + """ + yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) + return await self.db_pool.runInteraction( + "count_daily_users", self._count_users, yesterday + ) + + async def count_monthly_users(self) -> int: + """ + Counts the number of users who used this homeserver in the last 30 days. + Note this method is intended for phonehome metrics only and is different + from the mau figure in synapse.storage.monthly_active_users which, + amongst other things, includes a 3 day grace period before a user counts. + """ + thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) + return await self.db_pool.runInteraction( + "count_monthly_users", self._count_users, thirty_days_ago + ) + + def _count_users(self, txn, time_from): + """ + Returns number of users seen in the past time_from period + """ + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT user_id FROM user_ips + WHERE last_seen > ? + GROUP BY user_id + ) u + """ + txn.execute(sql, (time_from,)) + (count,) = txn.fetchone() + return count + + async def count_r30_users(self) -> Dict[str, int]: + """ + Counts the number of 30 day retained users, defined as:- + * Users who have created their accounts more than 30 days ago + * Where last seen at most 30 days ago + * Where account creation and last_seen are > 30 days apart + + Returns: + A mapping of counts globally as well as broken out by platform. 
+ """ + + def _count_r30_users(txn): + thirty_days_in_secs = 86400 * 30 + now = int(self._clock.time()) + thirty_days_ago_in_secs = now - thirty_days_in_secs + + sql = """ + SELECT platform, COALESCE(count(*), 0) FROM ( + SELECT + users.name, platform, users.creation_ts * 1000, + MAX(uip.last_seen) + FROM users + INNER JOIN ( + SELECT + user_id, + last_seen, + CASE + WHEN user_agent LIKE '%%Android%%' THEN 'android' + WHEN user_agent LIKE '%%iOS%%' THEN 'ios' + WHEN user_agent LIKE '%%Electron%%' THEN 'electron' + WHEN user_agent LIKE '%%Mozilla%%' THEN 'web' + WHEN user_agent LIKE '%%Gecko%%' THEN 'web' + ELSE 'unknown' + END + AS platform + FROM user_ips + ) uip + ON users.name = uip.user_id + AND users.appservice_id is NULL + AND users.creation_ts < ? + AND uip.last_seen/1000 > ? + AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 + GROUP BY users.name, platform, users.creation_ts + ) u GROUP BY platform + """ + + results = {} + txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) + + for row in txn: + if row[0] == "unknown": + pass + results[row[0]] = row[1] + + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT users.name, users.creation_ts * 1000, + MAX(uip.last_seen) + FROM users + INNER JOIN ( + SELECT + user_id, + last_seen + FROM user_ips + ) uip + ON users.name = uip.user_id + AND appservice_id is NULL + AND users.creation_ts < ? + AND uip.last_seen/1000 > ? + AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 + GROUP BY users.name, users.creation_ts + ) u + """ + + txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) + + (count,) = txn.fetchone() + results["all"] = count + + return results + + return await self.db_pool.runInteraction("count_r30_users", _count_r30_users) + + def _get_start_of_day(self): + """ + Returns millisecond unixtime for start of UTC day. + """ + now = time.gmtime() + today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) + return today_start * 1000 + + @wrap_as_background_process("generate_user_daily_visits") + async def generate_user_daily_visits(self) -> None: + """ + Generates daily visit data for use in cohort/ retention analysis + """ + + def _generate_user_daily_visits(txn): + logger.info("Calling _generate_user_daily_visits") + today_start = self._get_start_of_day() + a_day_in_milliseconds = 24 * 60 * 60 * 1000 + now = self._clock.time_msec() + + # A note on user_agent. Technically a given device can have multiple + # user agents, so we need to decide which one to pick. We could have + # handled this in number of ways, but given that we don't care + # _that_ much we have gone for MAX(). For more details of the other + # options considered see + # https://github.com/matrix-org/synapse/pull/8503#discussion_r502306111 + sql = """ + INSERT INTO user_daily_visits (user_id, device_id, timestamp, user_agent) + SELECT u.user_id, u.device_id, ?, MAX(u.user_agent) + FROM user_ips AS u + LEFT JOIN ( + SELECT user_id, device_id, timestamp FROM user_daily_visits + WHERE timestamp = ? + ) udv + ON u.user_id = udv.user_id AND u.device_id=udv.device_id + INNER JOIN users ON users.name=u.user_id + WHERE last_seen > ? AND last_seen <= ? + AND udv.timestamp IS NULL AND users.is_guest=0 + AND users.appservice_id IS NULL + GROUP BY u.user_id, u.device_id + """ + + # This means that the day has rolled over but there could still + # be entries from the previous day. 
There is an edge case + # where if the user logs in at 23:59 and overwrites their + # last_seen at 00:01 then they will not be counted in the + # previous day's stats - it is important that the query is run + # often to minimise this case. + if today_start > self._last_user_visit_update: + yesterday_start = today_start - a_day_in_milliseconds + txn.execute( + sql, + ( + yesterday_start, + yesterday_start, + self._last_user_visit_update, + today_start, + ), + ) + self._last_user_visit_update = today_start + + txn.execute( + sql, (today_start, today_start, self._last_user_visit_update, now) + ) + # Update _last_user_visit_update to now. The reason to do this + # rather just clamping to the beginning of the day is to limit + # the size of the join - meaning that the query can be run more + # frequently + self._last_user_visit_update = now + + await self.db_pool.runInteraction( + "generate_user_daily_visits", _generate_user_daily_visits + ) diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index e93aad33cd..d788dc0fc6 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py
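The metrics.py hunks above (and several other stores in this diff, including the monthly_active_users changes that follow) replace hand-rolled run_as_background_process wrappers with @wrap_as_background_process plus a looping_call guarded by hs.config.run_background_tasks. A minimal sketch of that pattern under those assumptions; the store name, method name and interval are made up.

    from synapse.metrics.background_process_metrics import wrap_as_background_process


    class ExampleStore:
        def __init__(self, database, db_conn, hs):
            self._clock = hs.get_clock()
            # Only the process configured to run background tasks schedules
            # the loop; every other worker skips it entirely.
            if hs.config.run_background_tasks:
                self._clock.looping_call(self._prune_old_rows, 60 * 60 * 1000)

        @wrap_as_background_process("prune_old_rows")
        async def _prune_old_rows(self):
            # The decorator supplies the logcontext and metrics that the old
            # hand-written run_as_background_process wrappers provided.
            ...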
@@ -15,6 +15,7 @@ import logging from typing import Dict, List +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_in_list_sql_clause from synapse.util.caches.descriptors import cached @@ -32,6 +33,9 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): self._clock = hs.get_clock() self.hs = hs + self._limit_usage_by_mau = hs.config.limit_usage_by_mau + self._max_mau_value = hs.config.max_mau_value + @cached(num_args=0) async def get_monthly_active_count(self) -> int: """Generates current count of monthly active users @@ -124,60 +128,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): desc="user_last_seen_monthly_active", ) - -class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): - super().__init__(database, db_conn, hs) - - self._limit_usage_by_mau = hs.config.limit_usage_by_mau - self._mau_stats_only = hs.config.mau_stats_only - self._max_mau_value = hs.config.max_mau_value - - # Do not add more reserved users than the total allowable number - # cur = LoggingTransaction( - self.db_pool.new_transaction( - db_conn, - "initialise_mau_threepids", - [], - [], - self._initialise_reserved_users, - hs.config.mau_limits_reserved_threepids[: self._max_mau_value], - ) - - def _initialise_reserved_users(self, txn, threepids): - """Ensures that reserved threepids are accounted for in the MAU table, should - be called on start up. - - Args: - txn (cursor): - threepids (list[dict]): List of threepid dicts to reserve - """ - - # XXX what is this function trying to achieve? It upserts into - # monthly_active_users for each *registered* reserved mau user, but why? - # - # - shouldn't there already be an entry for each reserved user (at least - # if they have been active recently)? - # - # - if it's important that the timestamp is kept up to date, why do we only - # run this at startup? - - for tp in threepids: - user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"]) - - if user_id: - is_support = self.is_support_user_txn(txn, user_id) - if not is_support: - # We do this manually here to avoid hitting #6791 - self.db_pool.simple_upsert_txn( - txn, - table="monthly_active_users", - keyvalues={"user_id": user_id}, - values={"timestamp": int(self._clock.time_msec())}, - ) - else: - logger.warning("mau limit reserved threepid %s not found in db" % tp) - + @wrap_as_background_process("reap_monthly_active_users") async def reap_monthly_active_users(self): """Cleans out monthly active user table to ensure that no stale entries exist. @@ -257,6 +208,57 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): "reap_monthly_active_users", _reap_users, reserved_users ) + +class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super().__init__(database, db_conn, hs) + + self._mau_stats_only = hs.config.mau_stats_only + + # Do not add more reserved users than the total allowable number + self.db_pool.new_transaction( + db_conn, + "initialise_mau_threepids", + [], + [], + self._initialise_reserved_users, + hs.config.mau_limits_reserved_threepids[: self._max_mau_value], + ) + + def _initialise_reserved_users(self, txn, threepids): + """Ensures that reserved threepids are accounted for in the MAU table, should + be called on start up. 
+ + Args: + txn (cursor): + threepids (list[dict]): List of threepid dicts to reserve + """ + + # XXX what is this function trying to achieve? It upserts into + # monthly_active_users for each *registered* reserved mau user, but why? + # + # - shouldn't there already be an entry for each reserved user (at least + # if they have been active recently)? + # + # - if it's important that the timestamp is kept up to date, why do we only + # run this at startup? + + for tp in threepids: + user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"]) + + if user_id: + is_support = self.is_support_user_txn(txn, user_id) + if not is_support: + # We do this manually here to avoid hitting #6791 + self.db_pool.simple_upsert_txn( + txn, + table="monthly_active_users", + keyvalues={"user_id": user_id}, + values={"timestamp": int(self._clock.time_msec())}, + ) + else: + logger.warning("mau limit reserved threepid %s not found in db" % tp) + async def upsert_monthly_active_user(self, user_id: str) -> None: """Updates or inserts the user into the monthly active user table, which is used to track the current MAU usage of the server diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index d2e0685e9e..0e25ca3d7a 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py
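In the monthly_active_users.py hunks above, the reserved-threepid initialisation stays in the writer store and is still capped at the configured MAU limit by slicing the reserved list. A small self-contained sketch of that truncation, with made-up config values.

    # Illustrative values only; the real ones come from homeserver config.
    max_mau_value = 2
    mau_limits_reserved_threepids = [
        {"medium": "email", "address": "a@example.com"},
        {"medium": "email", "address": "b@example.com"},
        {"medium": "email", "address": "c@example.com"},
    ]

    # As in the constructor above, only the first max_mau_value entries are
    # handed to _initialise_reserved_users; anything beyond the cap is dropped.
    reserved = mau_limits_reserved_threepids[:max_mau_value]
    assert [tp["address"] for tp in reserved] == ["a@example.com", "b@example.com"]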
@@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore @@ -39,7 +39,7 @@ class ProfileWorkerStore(SQLBaseStore): avatar_url=profile["avatar_url"], display_name=profile["displayname"] ) - async def get_profile_displayname(self, user_localpart: str) -> str: + async def get_profile_displayname(self, user_localpart: str) -> Optional[str]: return await self.db_pool.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, @@ -47,7 +47,7 @@ class ProfileWorkerStore(SQLBaseStore): desc="get_profile_displayname", ) - async def get_profile_avatar_url(self, user_localpart: str) -> str: + async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]: return await self.db_pool.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, @@ -72,7 +72,7 @@ class ProfileWorkerStore(SQLBaseStore): ) async def set_profile_displayname( - self, user_localpart: str, new_displayname: str + self, user_localpart: str, new_displayname: Optional[str] ) -> None: await self.db_pool.simple_update_one( table="profiles", @@ -91,27 +91,6 @@ class ProfileWorkerStore(SQLBaseStore): desc="set_profile_avatar_url", ) - -class ProfileStore(ProfileWorkerStore): - async def add_remote_profile_cache( - self, user_id: str, displayname: str, avatar_url: str - ) -> None: - """Ensure we are caching the remote user's profiles. - - This should only be called when `is_subscribed_remote_profile_for_user` - would return true for the user. - """ - await self.db_pool.simple_upsert( - table="remote_profile_cache", - keyvalues={"user_id": user_id}, - values={ - "displayname": displayname, - "avatar_url": avatar_url, - "last_check": self._clock.time_msec(), - }, - desc="add_remote_profile_cache", - ) - async def update_remote_profile_cache( self, user_id: str, displayname: str, avatar_url: str ) -> int: @@ -138,9 +117,34 @@ class ProfileStore(ProfileWorkerStore): desc="delete_remote_profile_cache", ) + async def is_subscribed_remote_profile_for_user(self, user_id): + """Check whether we are interested in a remote user's profile. + """ + res = await self.db_pool.simple_select_one_onecol( + table="group_users", + keyvalues={"user_id": user_id}, + retcol="user_id", + allow_none=True, + desc="should_update_remote_profile_cache_for_user", + ) + + if res: + return True + + res = await self.db_pool.simple_select_one_onecol( + table="group_invites", + keyvalues={"user_id": user_id}, + retcol="user_id", + allow_none=True, + desc="should_update_remote_profile_cache_for_user", + ) + + if res: + return True + async def get_remote_profile_cache_entries_that_expire( self, last_checked: int - ) -> Dict[str, str]: + ) -> List[Dict[str, str]]: """Get all users who haven't been checked since `last_checked` """ @@ -160,27 +164,23 @@ class ProfileStore(ProfileWorkerStore): _get_remote_profile_cache_entries_that_expire_txn, ) - async def is_subscribed_remote_profile_for_user(self, user_id): - """Check whether we are interested in a remote user's profile. 
- """ - res = await self.db_pool.simple_select_one_onecol( - table="group_users", - keyvalues={"user_id": user_id}, - retcol="user_id", - allow_none=True, - desc="should_update_remote_profile_cache_for_user", - ) - if res: - return True +class ProfileStore(ProfileWorkerStore): + async def add_remote_profile_cache( + self, user_id: str, displayname: str, avatar_url: str + ) -> None: + """Ensure we are caching the remote user's profiles. - res = await self.db_pool.simple_select_one_onecol( - table="group_invites", + This should only be called when `is_subscribed_remote_profile_for_user` + would return true for the user. + """ + await self.db_pool.simple_upsert( + table="remote_profile_cache", keyvalues={"user_id": user_id}, - retcol="user_id", - allow_none=True, - desc="should_update_remote_profile_cache_for_user", + values={ + "displayname": displayname, + "avatar_url": avatar_url, + "last_check": self._clock.time_msec(), + }, + desc="add_remote_profile_cache", ) - - if res: - return True diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index ecfc6717b3..5d668aadb2 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py
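The profile.py hunks above move is_subscribed_remote_profile_for_user onto the worker store while add_remote_profile_cache keeps its documented contract that the subscription check should pass first. A hedged sketch of a caller honouring that contract; the wrapper name is an assumption.

    async def maybe_cache_remote_profile(store, user_id, displayname, avatar_url):
        # Only cache profiles of remote users we are actually interested in,
        # i.e. users appearing in group_users or group_invites.
        if not await store.is_subscribed_remote_profile_for_user(user_id):
            return
        await store.add_remote_profile_cache(user_id, displayname, avatar_url)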
@@ -314,6 +314,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): for table in ( "event_auth", "event_edges", + "event_json", "event_push_actions_staging", "event_reference_hashes", "event_relations", @@ -340,7 +341,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): "destination_rooms", "event_backward_extremities", "event_forward_extremities", - "event_json", "event_push_actions", "event_search", "events", diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index df8609b97b..77ba9d819e 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py
@@ -15,18 +15,32 @@ # limitations under the License. import logging -from typing import Iterable, Iterator, List, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple from canonicaljson import encode_canonical_json +from synapse.push import PusherConfig, ThrottleParams from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import DatabasePool +from synapse.storage.types import Connection +from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.types import JsonDict from synapse.util.caches.descriptors import cached, cachedList +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class PusherWorkerStore(SQLBaseStore): - def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]: + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + super().__init__(database, db_conn, hs) + self._pushers_id_gen = StreamIdGenerator( + db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] + ) + + def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]: """JSON-decode the data in the rows returned from the `pushers` table Drops any rows whose data cannot be decoded @@ -44,21 +58,23 @@ class PusherWorkerStore(SQLBaseStore): ) continue - yield r + yield PusherConfig(**r) - async def user_has_pusher(self, user_id): + async def user_has_pusher(self, user_id: str) -> bool: ret = await self.db_pool.simple_select_one_onecol( "pushers", {"user_name": user_id}, "id", allow_none=True ) return ret is not None - def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey): - return self.get_pushers_by({"app_id": app_id, "pushkey": pushkey}) + async def get_pushers_by_app_id_and_pushkey( + self, app_id: str, pushkey: str + ) -> Iterator[PusherConfig]: + return await self.get_pushers_by({"app_id": app_id, "pushkey": pushkey}) - def get_pushers_by_user_id(self, user_id): - return self.get_pushers_by({"user_name": user_id}) + async def get_pushers_by_user_id(self, user_id: str) -> Iterator[PusherConfig]: + return await self.get_pushers_by({"user_name": user_id}) - async def get_pushers_by(self, keyvalues): + async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]: ret = await self.db_pool.simple_select_list( "pushers", keyvalues, @@ -83,7 +99,7 @@ class PusherWorkerStore(SQLBaseStore): ) return self._decode_pushers_rows(ret) - async def get_all_pushers(self): + async def get_all_pushers(self) -> Iterator[PusherConfig]: def get_pushers(txn): txn.execute("SELECT * FROM pushers") rows = self.db_pool.cursor_to_dict(txn) @@ -159,14 +175,16 @@ class PusherWorkerStore(SQLBaseStore): ) @cached(num_args=1, max_entries=15000) - async def get_if_user_has_pusher(self, user_id): + async def get_if_user_has_pusher(self, user_id: str): # This only exists for the cachedList decorator raise NotImplementedError() @cachedList( cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1, ) - async def get_if_users_have_pushers(self, user_ids): + async def get_if_users_have_pushers( + self, user_ids: Iterable[str] + ) -> Dict[str, bool]: rows = await self.db_pool.simple_select_many_batch( table="pushers", column="user_name", @@ -224,7 +242,7 @@ class PusherWorkerStore(SQLBaseStore): return bool(updated) async def update_pusher_failing_since( - self, app_id, pushkey, user_id, failing_since + self, app_id: str, pushkey: str, user_id: str, failing_since: Optional[int] ) 
-> None: await self.db_pool.simple_update( table="pushers", @@ -233,7 +251,9 @@ class PusherWorkerStore(SQLBaseStore): desc="update_pusher_failing_since", ) - async def get_throttle_params_by_room(self, pusher_id): + async def get_throttle_params_by_room( + self, pusher_id: str + ) -> Dict[str, ThrottleParams]: res = await self.db_pool.simple_select_list( "pusher_throttle", {"pusher": pusher_id}, @@ -243,43 +263,44 @@ class PusherWorkerStore(SQLBaseStore): params_by_room = {} for row in res: - params_by_room[row["room_id"]] = { - "last_sent_ts": row["last_sent_ts"], - "throttle_ms": row["throttle_ms"], - } + params_by_room[row["room_id"]] = ThrottleParams( + row["last_sent_ts"], row["throttle_ms"], + ) return params_by_room - async def set_throttle_params(self, pusher_id, room_id, params) -> None: + async def set_throttle_params( + self, pusher_id: str, room_id: str, params: ThrottleParams + ) -> None: # no need to lock because `pusher_throttle` has a primary key on # (pusher, room_id) so simple_upsert will retry await self.db_pool.simple_upsert( "pusher_throttle", {"pusher": pusher_id, "room_id": room_id}, - params, + {"last_sent_ts": params.last_sent_ts, "throttle_ms": params.throttle_ms}, desc="set_throttle_params", lock=False, ) class PusherStore(PusherWorkerStore): - def get_pushers_stream_token(self): + def get_pushers_stream_token(self) -> int: return self._pushers_id_gen.get_current_token() async def add_pusher( self, - user_id, - access_token, - kind, - app_id, - app_display_name, - device_display_name, - pushkey, - pushkey_ts, - lang, - data, - last_stream_ordering, - profile_tag="", + user_id: str, + access_token: Optional[int], + kind: str, + app_id: str, + app_display_name: str, + device_display_name: str, + pushkey: str, + pushkey_ts: int, + lang: Optional[str], + data: Optional[JsonDict], + last_stream_ordering: int, + profile_tag: str = "", ) -> None: async with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on @@ -303,7 +324,7 @@ class PusherStore(PusherWorkerStore): lock=False, ) - user_has_pusher = self.get_if_user_has_pusher.cache.get( + user_has_pusher = self.get_if_user_has_pusher.cache.get_immediate( (user_id,), None, update_metrics=False ) @@ -311,16 +332,16 @@ class PusherStore(PusherWorkerStore): # invalidate, since we the user might not have had a pusher before await self.db_pool.runInteraction( "add_pusher", - self._invalidate_cache_and_stream, + self._invalidate_cache_and_stream, # type: ignore self.get_if_user_has_pusher, (user_id,), ) async def delete_pusher_by_app_id_pushkey_user_id( - self, app_id, pushkey, user_id + self, app_id: str, pushkey: str, user_id: str ) -> None: def delete_pusher_txn(txn, stream_id): - self._invalidate_cache_and_stream( + self._invalidate_cache_and_stream( # type: ignore txn, self.get_if_user_has_pusher, (user_id,) ) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index c79ddff680..1e7949a323 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py
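The pusher.py hunks above replace the throttle-params dicts with the ThrottleParams type from synapse.push: get_throttle_params_by_room now returns a mapping of room ID to ThrottleParams and set_throttle_params accepts one. A hedged usage sketch; the pusher ID handling and the doubling policy are purely illustrative.

    from synapse.push import ThrottleParams


    async def bump_room_throttle(store, pusher_id, room_id, now_ms):
        params_by_room = await store.get_throttle_params_by_room(pusher_id)

        # Fields are (last_sent_ts, throttle_ms), constructed positionally as
        # in the hunk above.
        current = params_by_room.get(room_id) or ThrottleParams(0, 0)
        updated = ThrottleParams(now_ms, max(current.throttle_ms * 2, 60000))

        await store.set_throttle_params(pusher_id, room_id, updated)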
@@ -23,8 +23,8 @@ from twisted.internet import defer from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.types import JsonDict from synapse.util import json_encoder -from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -274,6 +274,65 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): } return results + @cached(num_args=2,) + async def get_linearized_receipts_for_all_rooms( + self, to_key: int, from_key: Optional[int] = None + ) -> Dict[str, JsonDict]: + """Get receipts for all rooms between two stream_ids, up + to a limit of the latest 100 read receipts. + + Args: + to_key: Max stream id to fetch receipts upto. + from_key: Min stream id to fetch receipts from. None fetches + from the start. + + Returns: + A dictionary of roomids to a list of receipts. + """ + + def f(txn): + if from_key: + sql = """ + SELECT * FROM receipts_linearized WHERE + stream_id > ? AND stream_id <= ? + ORDER BY stream_id DESC + LIMIT 100 + """ + txn.execute(sql, [from_key, to_key]) + else: + sql = """ + SELECT * FROM receipts_linearized WHERE + stream_id <= ? + ORDER BY stream_id DESC + LIMIT 100 + """ + + txn.execute(sql, [to_key]) + + return self.db_pool.cursor_to_dict(txn) + + txn_results = await self.db_pool.runInteraction( + "get_linearized_receipts_for_all_rooms", f + ) + + results = {} + for row in txn_results: + # We want a single event per room, since we want to batch the + # receipts by room, event and type. + room_event = results.setdefault( + row["room_id"], + {"type": "m.receipt", "room_id": row["room_id"], "content": {}}, + ) + + # The content is of the form: + # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. } + event_entry = room_event["content"].setdefault(row["event_id"], {}) + receipt_type = event_entry.setdefault(row["receipt_type"], {}) + + receipt_type[row["user_id"]] = db_to_json(row["data"]) + + return results + async def get_users_sent_receipts_between( self, last_id: int, current_id: int ) -> List[str]: @@ -358,18 +417,10 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): if receipt_type != "m.read": return - # Returns either an ObservableDeferred or the raw result - res = self.get_users_with_read_receipts_in_room.cache.get( + res = self.get_users_with_read_receipts_in_room.cache.get_immediate( room_id, None, update_metrics=False ) - # first handle the ObservableDeferred case - if isinstance(res, ObservableDeferred): - if res.has_called(): - res = res.get_result() - else: - res = None - if res and user_id in res: # We'd only be adding to the set, so no point invalidating if the # user is already there diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index a83df7759d..8d05288ed4 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py
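The receipts.py hunks above add get_linearized_receipts_for_all_rooms, which batches the latest read receipts into one m.receipt event per room. A small sketch of the returned shape, using invented room, event and user IDs.

    # Fake IDs, shaped like the dictionaries built in the transaction above.
    results = {
        "!room:example.org": {
            "type": "m.receipt",
            "room_id": "!room:example.org",
            "content": {
                "$event1:example.org": {
                    "m.read": {"@alice:example.org": {"ts": 1600000000000}},
                },
            },
        },
    }

    for room_id, receipt_event in results.items():
        # One batched m.receipt event per room, keyed by event ID and then
        # receipt type within the content.
        assert receipt_event["type"] == "m.receipt"
        assert receipt_event["room_id"] == room_id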
@@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017-2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019,2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,32 +14,66 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging import re -from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +import attr from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import SQLBaseStore +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage.database import DatabasePool -from synapse.storage.types import Cursor +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore +from synapse.storage.databases.main.stats import StatsStore +from synapse.storage.types import Connection, Cursor +from synapse.storage.util.id_generators import IdGenerator from synapse.storage.util.sequence import build_sequence_generator from synapse.types import UserID from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.server import HomeServer + THIRTY_MINUTES_IN_MS = 30 * 60 * 1000 logger = logging.getLogger(__name__) -class RegistrationWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): +@attr.s(frozen=True, slots=True) +class TokenLookupResult: + """Result of looking up an access token. + + Attributes: + user_id: The user that this token authenticates as + is_guest + shadow_banned + token_id: The ID of the access token looked up + device_id: The device associated with the token, if any. + valid_until_ms: The timestamp the token expires, if any. + token_owner: The "owner" of the token. This is either the same as the + user, or a server admin who is logged in as the user. + """ + + user_id = attr.ib(type=str) + is_guest = attr.ib(type=bool, default=False) + shadow_banned = attr.ib(type=bool, default=False) + token_id = attr.ib(type=Optional[int], default=None) + device_id = attr.ib(type=Optional[str], default=None) + valid_until_ms = attr.ib(type=Optional[int], default=None) + token_owner = attr.ib(type=str) + + # Make the token owner default to the user ID, which is the common case. 
+ @token_owner.default + def _default_token_owner(self): + return self.user_id + + +class RegistrationWorkerStore(CacheInvalidationWorkerStore): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.config = hs.config - self.clock = hs.get_clock() # Note: we don't check this sequence for consistency as we'd have to # call `find_max_generated_user_id_localpart` each time, which is @@ -48,6 +82,18 @@ class RegistrationWorkerStore(SQLBaseStore): database.engine, find_max_generated_user_id_localpart, "user_id_seq", ) + self._account_validity = hs.config.account_validity + if hs.config.run_background_tasks and self._account_validity.enabled: + self._clock.call_later( + 0.0, self._set_expiration_date_when_missing, + ) + + # Create a background job for culling expired 3PID validity tokens + if hs.config.run_background_tasks: + self._clock.looping_call( + self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS + ) + @cached() async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: return await self.db_pool.simple_select_one( @@ -81,21 +127,19 @@ class RegistrationWorkerStore(SQLBaseStore): if not info: return False - now = self.clock.time_msec() + now = self._clock.time_msec() trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000 is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms return is_trial @cached() - async def get_user_by_access_token(self, token: str) -> Optional[dict]: + async def get_user_by_access_token(self, token: str) -> Optional[TokenLookupResult]: """Get a user from the given access token. Args: token: The access token of a user. Returns: - None, if the token did not match, otherwise dict - including the keys `name`, `is_guest`, `device_id`, `token_id`, - `valid_until_ms`. + None, if the token did not match, otherwise a `TokenLookupResult` """ return await self.db_pool.runInteraction( "get_user_by_access_token", self._query_for_auth, token @@ -225,13 +269,13 @@ class RegistrationWorkerStore(SQLBaseStore): desc="get_renewal_token_for_user", ) - async def get_users_expiring_soon(self) -> List[Dict[str, int]]: + async def get_users_expiring_soon(self) -> List[Dict[str, Any]]: """Selects users whose account will expire in the [now, now + renew_at] time window (see configuration for account_validity for information on what renew_at refers to). Returns: - A list of dictionaries mapping user ID to expiration time (in milliseconds). + A list of dictionaries, each with a user ID and expiration time (in milliseconds). """ def select_users_txn(txn, now_ms, renew_at): @@ -246,7 +290,7 @@ class RegistrationWorkerStore(SQLBaseStore): return await self.db_pool.runInteraction( "get_users_expiring_soon", select_users_txn, - self.clock.time_msec(), + self._clock.time_msec(), self.config.account_validity.renew_at, ) @@ -316,19 +360,24 @@ class RegistrationWorkerStore(SQLBaseStore): await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) - def _query_for_auth(self, txn, token): - sql = ( - "SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id," - " access_tokens.device_id, access_tokens.valid_until_ms" - " FROM users" - " INNER JOIN access_tokens on users.name = access_tokens.user_id" - " WHERE token = ?" 
- ) + def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]: + sql = """ + SELECT users.name as user_id, + users.is_guest, + users.shadow_banned, + access_tokens.id as token_id, + access_tokens.device_id, + access_tokens.valid_until_ms, + access_tokens.user_id as token_owner + FROM users + INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id) + WHERE token = ? + """ txn.execute(sql, (token,)) rows = self.db_pool.cursor_to_dict(txn) if rows: - return rows[0] + return TokenLookupResult(**rows[0]) return None @@ -414,6 +463,23 @@ class RegistrationWorkerStore(SQLBaseStore): desc="get_user_by_external_id", ) + async def get_external_ids_by_user(self, mxid: str) -> List[Tuple[str, str]]: + """Look up external ids for the given user + + Args: + mxid: the MXID to be looked up + + Returns: + Tuples of (auth_provider, external_id) + """ + res = await self.db_pool.simple_select_list( + table="user_external_ids", + keyvalues={"user_id": mxid}, + retcols=("auth_provider", "external_id"), + desc="get_external_ids_by_user", + ) + return [(r["auth_provider"], r["external_id"]) for r in res] + async def count_all_users(self): """Counts all users registered on the homeserver.""" @@ -778,12 +844,147 @@ class RegistrationWorkerStore(SQLBaseStore): "delete_threepid_session", delete_threepid_session_txn ) + @wrap_as_background_process("cull_expired_threepid_validation_tokens") + async def cull_expired_threepid_validation_tokens(self) -> None: + """Remove threepid validation tokens with expiry dates that have passed""" + + def cull_expired_threepid_validation_tokens_txn(txn, ts): + sql = """ + DELETE FROM threepid_validation_token WHERE + expires < ? + """ + txn.execute(sql, (ts,)) + + await self.db_pool.runInteraction( + "cull_expired_threepid_validation_tokens", + cull_expired_threepid_validation_tokens_txn, + self._clock.time_msec(), + ) + + @wrap_as_background_process("account_validity_set_expiration_dates") + async def _set_expiration_date_when_missing(self): + """ + Retrieves the list of registered users that don't have an expiration date, and + adds an expiration date for each of them. + """ + + def select_users_with_no_expiration_date_txn(txn): + """Retrieves the list of registered users with no expiration date from the + database, filtering out deactivated users. + """ + sql = ( + "SELECT users.name FROM users" + " LEFT JOIN account_validity ON (users.name = account_validity.user_id)" + " WHERE account_validity.user_id is NULL AND users.deactivated = 0;" + ) + txn.execute(sql, []) + + res = self.db_pool.cursor_to_dict(txn) + if res: + for user in res: + self.set_expiration_date_for_user_txn( + txn, user["name"], use_delta=True + ) + + await self.db_pool.runInteraction( + "get_users_with_no_expiration_date", + select_users_with_no_expiration_date_txn, + ) + + def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): + """Sets an expiration date to the account with the given user ID. + + Args: + user_id (str): User ID to set an expiration date for. + use_delta (bool): If set to False, the expiration date for the user will be + now + validity period. If set to True, this expiration date will be a + random value in the [now + period - d ; now + period] range, d being a + delta equal to 10% of the validity period. 
+ """ + now_ms = self._clock.time_msec() + expiration_ts = now_ms + self._account_validity.period + + if use_delta: + expiration_ts = self.rand.randrange( + expiration_ts - self._account_validity.startup_job_max_delta, + expiration_ts, + ) + + self.db_pool.simple_upsert_txn( + txn, + "account_validity", + keyvalues={"user_id": user_id}, + values={"expiration_ts_ms": expiration_ts, "email_sent": False}, + ) + + async def get_user_pending_deactivation(self) -> Optional[str]: + """ + Gets one user from the table of users waiting to be parted from all the rooms + they're in. + """ + return await self.db_pool.simple_select_one_onecol( + "users_pending_deactivation", + keyvalues={}, + retcol="user_id", + allow_none=True, + desc="get_users_pending_deactivation", + ) + + async def del_user_pending_deactivation(self, user_id: str) -> None: + """ + Removes the given user to the table of users who need to be parted from all the + rooms they're in, effectively marking that user as fully deactivated. + """ + # XXX: This should be simple_delete_one but we failed to put a unique index on + # the table, so somehow duplicate entries have ended up in it. + await self.db_pool.simple_delete( + "users_pending_deactivation", + keyvalues={"user_id": user_id}, + desc="del_user_pending_deactivation", + ) + + async def get_access_token_last_validated(self, token_id: int) -> int: + """Retrieves the time (in milliseconds) of the last validation of an access token. + + Args: + token_id: The ID of the access token to update. + Raises: + StoreError if the access token was not found. + + Returns: + The last validation time. + """ + result = await self.db_pool.simple_select_one_onecol( + "access_tokens", {"id": token_id}, "last_validated" + ) + + # If this token has not been validated (since starting to track this), + # return 0 instead of None. + return result or 0 + + async def update_access_token_last_validated(self, token_id: int) -> None: + """Updates the last time an access token was validated. + + Args: + token_id: The ID of the access token to update. + Raises: + StoreError if there was a problem updating this. + """ + now = self._clock.time_msec() + + await self.db_pool.simple_update_one( + "access_tokens", + {"id": token_id}, + {"last_validated": now}, + desc="update_access_token_last_validated", + ) + class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): super().__init__(database, db_conn, hs) - self.clock = hs.get_clock() + self._clock = hs.get_clock() self.config = hs.config self.db_pool.updates.register_background_index_update( @@ -815,6 +1016,14 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): "users_set_deactivated_flag", self._background_update_set_deactivated_flag ) + self.db_pool.updates.register_background_index_update( + "user_external_ids_user_id_idx", + index_name="user_external_ids_user_id_idx", + table="user_external_ids", + columns=["user_id"], + unique=False, + ) + async def _background_update_set_deactivated_flag(self, progress, batch_size): """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1 for each of them. @@ -906,32 +1115,55 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): return 1 + async def set_user_deactivated_status( + self, user_id: str, deactivated: bool + ) -> None: + """Set the `deactivated` property for the provided user to the provided value. 
-class RegistrationStore(RegistrationBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn, hs): - super().__init__(database, db_conn, hs) + Args: + user_id: The ID of the user to set the status for. + deactivated: The value to set for `deactivated`. + """ - self._account_validity = hs.config.account_validity - self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors + await self.db_pool.runInteraction( + "set_user_deactivated_status", + self.set_user_deactivated_status_txn, + user_id, + deactivated, + ) - if self._account_validity.enabled: - self._clock.call_later( - 0.0, - run_as_background_process, - "account_validity_set_expiration_dates", - self._set_expiration_date_when_missing, - ) + def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool): + self.db_pool.simple_update_one_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"deactivated": 1 if deactivated else 0}, + ) + self._invalidate_cache_and_stream( + txn, self.get_user_deactivated_status, (user_id,) + ) + txn.call_after(self.is_guest.invalidate, (user_id,)) - # Create a background job for culling expired 3PID validity tokens - def start_cull(): - # run as a background process to make sure that the database transactions - # have a logcontext to report to - return run_as_background_process( - "cull_expired_threepid_validation_tokens", - self.cull_expired_threepid_validation_tokens, - ) + @cached() + async def is_guest(self, user_id: str) -> bool: + res = await self.db_pool.simple_select_one_onecol( + table="users", + keyvalues={"name": user_id}, + retcol="is_guest", + allow_none=True, + desc="is_guest", + ) + + return res if res else False - hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS) + +class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + super().__init__(database, db_conn, hs) + + self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors + + self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") async def add_access_token_to_user( self, @@ -939,7 +1171,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): token: str, device_id: Optional[str], valid_until_ms: Optional[int], - ) -> None: + puppets_user_id: Optional[str] = None, + ) -> int: """Adds an access token for the given user. Args: @@ -949,8 +1182,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): valid_until_ms: when the token is valid until. None for no expiry. Raises: StoreError if there was a problem adding this. 
+ Returns: + The token ID """ next_id = self._access_tokens_id_gen.get_next() + now = self._clock.time_msec() await self.db_pool.simple_insert( "access_tokens", @@ -960,10 +1196,44 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): "token": token, "device_id": device_id, "valid_until_ms": valid_until_ms, + "puppets_user_id": puppets_user_id, + "last_validated": now, }, desc="add_access_token_to_user", ) + return next_id + + def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str: + old_device_id = self.db_pool.simple_select_one_onecol_txn( + txn, "access_tokens", {"token": token}, "device_id" + ) + + self.db_pool.simple_update_txn( + txn, "access_tokens", {"token": token}, {"device_id": device_id} + ) + + self._invalidate_cache_and_stream(txn, self.get_user_by_access_token, (token,)) + + return old_device_id + + async def set_device_for_access_token(self, token: str, device_id: str) -> str: + """Sets the device ID associated with an access token. + + Args: + token: The access token to modify. + device_id: The new device ID. + Returns: + The old device ID associated with the access token. + """ + + return await self.db_pool.runInteraction( + "set_device_for_access_token", + self._set_device_for_access_token_txn, + token, + device_id, + ) + async def register_user( self, user_id: str, @@ -1014,19 +1284,19 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): def _register_user( self, txn, - user_id, - password_hash, - was_guest, - make_guest, - appservice_id, - create_profile_with_displayname, - admin, - user_type, - shadow_banned, + user_id: str, + password_hash: Optional[str], + was_guest: bool, + make_guest: bool, + appservice_id: Optional[str], + create_profile_with_displayname: Optional[str], + admin: bool, + user_type: Optional[str], + shadow_banned: bool, ): user_id_obj = UserID.from_string(user_id) - now = int(self.clock.time()) + now = int(self._clock.time()) try: if was_guest: @@ -1121,7 +1391,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): desc="record_user_external_id", ) - async def user_set_password_hash(self, user_id: str, password_hash: str) -> None: + async def user_set_password_hash( + self, user_id: str, password_hash: Optional[str] + ) -> None: """ NB. This does *not* evict any cache because the one use for this removes most of the entries subsequently anyway so it would be @@ -1248,18 +1520,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): await self.db_pool.runInteraction("delete_access_token", f) - @cached() - async def is_guest(self, user_id: str) -> bool: - res = await self.db_pool.simple_select_one_onecol( - table="users", - keyvalues={"name": user_id}, - retcol="is_guest", - allow_none=True, - desc="is_guest", - ) - - return res if res else False - async def add_user_pending_deactivation(self, user_id: str) -> None: """ Adds a user to the table of users who need to be parted from all the rooms they're @@ -1271,32 +1531,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): desc="add_user_pending_deactivation", ) - async def del_user_pending_deactivation(self, user_id: str) -> None: - """ - Removes the given user to the table of users who need to be parted from all the - rooms they're in, effectively marking that user as fully deactivated. - """ - # XXX: This should be simple_delete_one but we failed to put a unique index on - # the table, so somehow duplicate entries have ended up in it. 
- await self.db_pool.simple_delete( - "users_pending_deactivation", - keyvalues={"user_id": user_id}, - desc="del_user_pending_deactivation", - ) - - async def get_user_pending_deactivation(self) -> Optional[str]: - """ - Gets one user from the table of users waiting to be parted from all the rooms - they're in. - """ - return await self.db_pool.simple_select_one_onecol( - "users_pending_deactivation", - keyvalues={}, - retcol="user_id", - allow_none=True, - desc="get_users_pending_deactivation", - ) - async def validate_threepid_session( self, session_id: str, client_secret: str, token: str, current_ts: int ) -> Optional[str]: @@ -1379,7 +1613,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, - updatevalues={"validated_at": self.clock.time_msec()}, + updatevalues={"validated_at": self._clock.time_msec()}, ) return next_link @@ -1447,106 +1681,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): start_or_continue_validation_session_txn, ) - async def cull_expired_threepid_validation_tokens(self) -> None: - """Remove threepid validation tokens with expiry dates that have passed""" - - def cull_expired_threepid_validation_tokens_txn(txn, ts): - sql = """ - DELETE FROM threepid_validation_token WHERE - expires < ? - """ - txn.execute(sql, (ts,)) - - await self.db_pool.runInteraction( - "cull_expired_threepid_validation_tokens", - cull_expired_threepid_validation_tokens_txn, - self.clock.time_msec(), - ) - - async def set_user_deactivated_status( - self, user_id: str, deactivated: bool - ) -> None: - """Set the `deactivated` property for the provided user to the provided value. - - Args: - user_id: The ID of the user to set the status for. - deactivated: The value to set for `deactivated`. - """ - - await self.db_pool.runInteraction( - "set_user_deactivated_status", - self.set_user_deactivated_status_txn, - user_id, - deactivated, - ) - - def set_user_deactivated_status_txn(self, txn, user_id, deactivated): - self.db_pool.simple_update_one_txn( - txn=txn, - table="users", - keyvalues={"name": user_id}, - updatevalues={"deactivated": 1 if deactivated else 0}, - ) - self._invalidate_cache_and_stream( - txn, self.get_user_deactivated_status, (user_id,) - ) - txn.call_after(self.is_guest.invalidate, (user_id,)) - - async def _set_expiration_date_when_missing(self): - """ - Retrieves the list of registered users that don't have an expiration date, and - adds an expiration date for each of them. - """ - - def select_users_with_no_expiration_date_txn(txn): - """Retrieves the list of registered users with no expiration date from the - database, filtering out deactivated users. - """ - sql = ( - "SELECT users.name FROM users" - " LEFT JOIN account_validity ON (users.name = account_validity.user_id)" - " WHERE account_validity.user_id is NULL AND users.deactivated = 0;" - ) - txn.execute(sql, []) - - res = self.db_pool.cursor_to_dict(txn) - if res: - for user in res: - self.set_expiration_date_for_user_txn( - txn, user["name"], use_delta=True - ) - - await self.db_pool.runInteraction( - "get_users_with_no_expiration_date", - select_users_with_no_expiration_date_txn, - ) - - def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): - """Sets an expiration date to the account with the given user ID. - - Args: - user_id (str): User ID to set an expiration date for. - use_delta (bool): If set to False, the expiration date for the user will be - now + validity period. 
If set to True, this expiration date will be a - random value in the [now + period - d ; now + period] range, d being a - delta equal to 10% of the validity period. - """ - now_ms = self._clock.time_msec() - expiration_ts = now_ms + self._account_validity.period - - if use_delta: - expiration_ts = self.rand.randrange( - expiration_ts - self._account_validity.startup_job_max_delta, - expiration_ts, - ) - - self.db_pool.simple_upsert_txn( - txn, - "account_validity", - keyvalues={"user_id": user_id}, - values={"expiration_ts_ms": expiration_ts, "email_sent": False}, - ) - def find_max_generated_user_id_localpart(cur: Cursor) -> int: """ diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 3c7630857f..4650d0689b 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py
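One change in the room.py hunks below makes the admin room-list search case-insensitive by lower-casing both the stored room name and the search term before the LIKE comparison. As a rough, standalone sketch of that scheme (this helper is invented for illustration and is not part of the diff, which builds the clause inline inside its paginated query):

def build_room_name_filter(search_term):
    """Minimal sketch of the case-insensitive room-name filter used below.

    Lower-cases both sides and wraps the term in SQL wildcards, mirroring the
    "LOWER(state.name) LIKE ?" change; placeholder handling is simplified.
    """
    if not search_term:
        return "", []
    return "WHERE LOWER(state.name) LIKE ?", ["%" + search_term.lower() + "%"]


# Example: ('WHERE LOWER(state.name) LIKE ?', ['%matrix hq%'])
print(build_room_name_filter("Matrix HQ"))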
@@ -192,6 +192,18 @@ class RoomWorkerStore(SQLBaseStore): "count_public_rooms", _count_public_rooms_txn ) + async def get_room_count(self) -> int: + """Retrieve the total number of rooms. + """ + + def f(txn): + sql = "SELECT count(*) FROM rooms" + txn.execute(sql) + row = txn.fetchone() + return row[0] or 0 + + return await self.db_pool.runInteraction("get_rooms", f) + async def get_largest_public_rooms( self, network_tuple: Optional[ThirdPartyInstanceID], @@ -367,14 +379,14 @@ class RoomWorkerStore(SQLBaseStore): # Filter room names by a string where_statement = "" if search_term: - where_statement = "WHERE state.name LIKE ?" + where_statement = "WHERE LOWER(state.name) LIKE ?" # Our postgres db driver converts ? -> %s in SQL strings as that's the # placeholder for postgres. # HOWEVER, if you put a % into your SQL then everything goes wibbly. # To get around this, we're going to surround search_term with %'s # before giving it to the database in python instead - search_term = "%" + search_term + "%" + search_term = "%" + search_term.lower() + "%" # Set ordering if RoomSortOrder(order_by) == RoomSortOrder.SIZE: @@ -857,6 +869,89 @@ class RoomWorkerStore(SQLBaseStore): "get_all_new_public_rooms", get_all_new_public_rooms ) + async def get_rooms_for_retention_period_in_range( + self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False + ) -> Dict[str, dict]: + """Retrieves all of the rooms within the given retention range. + + Optionally includes the rooms which don't have a retention policy. + + Args: + min_ms: Duration in milliseconds that define the lower limit of + the range to handle (exclusive). If None, doesn't set a lower limit. + max_ms: Duration in milliseconds that define the upper limit of + the range to handle (inclusive). If None, doesn't set an upper limit. + include_null: Whether to include rooms which retention policy is NULL + in the returned set. + + Returns: + The rooms within this range, along with their retention + policy. The key is "room_id", and maps to a dict describing the retention + policy associated with this room ID. The keys for this nested dict are + "min_lifetime" (int|None), and "max_lifetime" (int|None). + """ + + def get_rooms_for_retention_period_in_range_txn(txn): + range_conditions = [] + args = [] + + if min_ms is not None: + range_conditions.append("max_lifetime > ?") + args.append(min_ms) + + if max_ms is not None: + range_conditions.append("max_lifetime <= ?") + args.append(max_ms) + + # Do a first query which will retrieve the rooms that have a retention policy + # in their current state. + sql = """ + SELECT room_id, min_lifetime, max_lifetime FROM room_retention + INNER JOIN current_state_events USING (event_id, room_id) + """ + + if len(range_conditions): + sql += " WHERE (" + " AND ".join(range_conditions) + ")" + + if include_null: + sql += " OR max_lifetime IS NULL" + + txn.execute(sql, args) + + rows = self.db_pool.cursor_to_dict(txn) + rooms_dict = {} + + for row in rows: + rooms_dict[row["room_id"]] = { + "min_lifetime": row["min_lifetime"], + "max_lifetime": row["max_lifetime"], + } + + if include_null: + # If required, do a second query that retrieves all of the rooms we know + # of so we can handle rooms with no retention policy. + sql = "SELECT DISTINCT room_id FROM current_state_events" + + txn.execute(sql) + + rows = self.db_pool.cursor_to_dict(txn) + + # If a room isn't already in the dict (i.e. it doesn't have a retention + # policy in its state), add it with a null policy. 
+ for row in rows: + if row["room_id"] not in rooms_dict: + rooms_dict[row["room_id"]] = { + "min_lifetime": None, + "max_lifetime": None, + } + + return rooms_dict + + return await self.db_pool.runInteraction( + "get_rooms_for_retention_period_in_range", + get_rooms_for_retention_period_in_range_txn, + ) + class RoomBackgroundUpdateStore(SQLBaseStore): REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" @@ -1145,13 +1240,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") - async def maybe_store_room_on_invite(self, room_id: str, room_version: RoomVersion): + async def maybe_store_room_on_outlier_membership( + self, room_id: str, room_version: RoomVersion + ): """ - When we receive an invite over federation, store the version of the room if we - don't already know the room version. + When we receive an invite or any other event over federation that may relate to a room + we are not in, store the version of the room if we don't already know the room version. """ await self.db_pool.simple_upsert( - desc="maybe_store_room_on_invite", + desc="maybe_store_room_on_outlier_membership", table="rooms", keyvalues={"room_id": room_id}, values={}, @@ -1292,18 +1389,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) self.hs.get_notifier().on_new_replication_data() - async def get_room_count(self) -> int: - """Retrieve the total number of rooms. - """ - - def f(txn): - sql = "SELECT count(*) FROM rooms" - txn.execute(sql) - row = txn.fetchone() - return row[0] or 0 - - return await self.db_pool.runInteraction("get_rooms", f) - async def add_event_report( self, room_id: str, @@ -1328,6 +1413,65 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): desc="add_event_report", ) + async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]: + """Retrieve an event report + + Args: + report_id: ID of reported event in database + Returns: + event_report: json list of information from event report + """ + + def _get_event_report_txn(txn, report_id): + + sql = """ + SELECT + er.id, + er.received_ts, + er.room_id, + er.event_id, + er.user_id, + er.content, + events.sender, + room_stats_state.canonical_alias, + room_stats_state.name, + event_json.json AS event_json + FROM event_reports AS er + LEFT JOIN events + ON events.event_id = er.event_id + JOIN event_json + ON event_json.event_id = er.event_id + JOIN room_stats_state + ON room_stats_state.room_id = er.room_id + WHERE er.id = ? 
+ """ + + txn.execute(sql, [report_id]) + row = txn.fetchone() + + if not row: + return None + + event_report = { + "id": row[0], + "received_ts": row[1], + "room_id": row[2], + "event_id": row[3], + "user_id": row[4], + "score": db_to_json(row[5]).get("score"), + "reason": db_to_json(row[5]).get("reason"), + "sender": row[6], + "canonical_alias": row[7], + "name": row[8], + "event_json": db_to_json(row[9]), + } + + return event_report + + return await self.db_pool.runInteraction( + "get_event_report", _get_event_report_txn, report_id + ) + async def get_event_reports_paginate( self, start: int, @@ -1385,18 +1529,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): er.room_id, er.event_id, er.user_id, - er.reason, er.content, events.sender, - room_aliases.room_alias, - event_json.json AS event_json + room_stats_state.canonical_alias, + room_stats_state.name FROM event_reports AS er - LEFT JOIN room_aliases - ON room_aliases.room_id = er.room_id - JOIN events + LEFT JOIN events ON events.event_id = er.event_id - JOIN event_json - ON event_json.event_id = er.event_id + JOIN room_stats_state + ON room_stats_state.room_id = er.room_id {where_clause} ORDER BY er.received_ts {order} LIMIT ? @@ -1407,15 +1548,29 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): args += [limit, start] txn.execute(sql, args) - event_reports = self.db_pool.cursor_to_dict(txn) - - if count > 0: - for row in event_reports: - try: - row["content"] = db_to_json(row["content"]) - row["event_json"] = db_to_json(row["event_json"]) - except Exception: - continue + + event_reports = [] + for row in txn: + try: + s = db_to_json(row[5]).get("score") + r = db_to_json(row[5]).get("reason") + except Exception: + logger.error("Unable to parse json from event_reports: %s", row[0]) + continue + event_reports.append( + { + "id": row[0], + "received_ts": row[1], + "room_id": row[2], + "event_id": row[3], + "user_id": row[4], + "score": s, + "reason": r, + "sender": row[6], + "canonical_alias": row[7], + "name": row[8], + } + ) return event_reports, count @@ -1446,88 +1601,3 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): self.is_room_blocked, (room_id,), ) - - async def get_rooms_for_retention_period_in_range( - self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False - ) -> Dict[str, dict]: - """Retrieves all of the rooms within the given retention range. - - Optionally includes the rooms which don't have a retention policy. - - Args: - min_ms: Duration in milliseconds that define the lower limit of - the range to handle (exclusive). If None, doesn't set a lower limit. - max_ms: Duration in milliseconds that define the upper limit of - the range to handle (inclusive). If None, doesn't set an upper limit. - include_null: Whether to include rooms which retention policy is NULL - in the returned set. - - Returns: - The rooms within this range, along with their retention - policy. The key is "room_id", and maps to a dict describing the retention - policy associated with this room ID. The keys for this nested dict are - "min_lifetime" (int|None), and "max_lifetime" (int|None). 
- """ - - def get_rooms_for_retention_period_in_range_txn(txn): - range_conditions = [] - args = [] - - if min_ms is not None: - range_conditions.append("max_lifetime > ?") - args.append(min_ms) - - if max_ms is not None: - range_conditions.append("max_lifetime <= ?") - args.append(max_ms) - - # Do a first query which will retrieve the rooms that have a retention policy - # in their current state. - sql = """ - SELECT room_id, min_lifetime, max_lifetime FROM room_retention - INNER JOIN current_state_events USING (event_id, room_id) - """ - - if len(range_conditions): - sql += " WHERE (" + " AND ".join(range_conditions) + ")" - - if include_null: - sql += " OR max_lifetime IS NULL" - - txn.execute(sql, args) - - rows = self.db_pool.cursor_to_dict(txn) - rooms_dict = {} - - for row in rows: - rooms_dict[row["room_id"]] = { - "min_lifetime": row["min_lifetime"], - "max_lifetime": row["max_lifetime"], - } - - if include_null: - # If required, do a second query that retrieves all of the rooms we know - # of so we can handle rooms with no retention policy. - sql = "SELECT DISTINCT room_id FROM current_state_events" - - txn.execute(sql) - - rows = self.db_pool.cursor_to_dict(txn) - - # If a room isn't already in the dict (i.e. it doesn't have a retention - # policy in its state), add it with a null policy. - for row in rows: - if row["room_id"] not in rooms_dict: - rooms_dict[row["room_id"]] = { - "min_lifetime": None, - "max_lifetime": None, - } - - return rooms_dict - - rooms = await self.db_pool.runInteraction( - "get_rooms_for_retention_period_in_range", - get_rooms_for_retention_period_in_range_txn, - ) - - return rooms diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 86ffe2479e..dcdaf09682 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py
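Among other things, the roommember.py hunks below add get_local_current_membership_for_user_in_room, which reads a (membership, event_id) pair from the local_current_membership table and raises for non-local users. A hedged usage sketch follows; the wrapper function here is invented for illustration and assumes a store object exposing the new method:

async def is_user_joined(store, user_id: str, room_id: str) -> bool:
    """Illustrative caller (not part of the diff): checks local membership via
    the new store method below. Both tuple elements are None when no local
    membership row exists for the user/room pair."""
    membership, _event_id = await store.get_local_current_membership_for_user_in_room(
        user_id, room_id
    )
    return membership == "join"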
@@ -14,19 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set +from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.metrics import LaterGauge -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import ( - LoggingTransaction, - SQLBaseStore, - db_to_json, - make_in_list_sql_clause, +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, ) +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import Sqlite3Engine @@ -60,27 +58,22 @@ class RoomMemberWorkerStore(EventsWorkerStore): # background update still running? self._current_state_events_membership_up_to_date = False - txn = LoggingTransaction( - db_conn.cursor(), - name="_check_safe_current_state_events_membership_updated", - database_engine=self.database_engine, + txn = db_conn.cursor( + txn_name="_check_safe_current_state_events_membership_updated" ) self._check_safe_current_state_events_membership_updated_txn(txn) txn.close() - if self.hs.config.metrics_flags.known_servers: + if ( + self.hs.config.run_background_tasks + and self.hs.config.metrics_flags.known_servers + ): self._known_servers_count = 1 self.hs.get_clock().looping_call( - run_as_background_process, - 60 * 1000, - "_count_known_servers", - self._count_known_servers, + self._count_known_servers, 60 * 1000, ) self.hs.get_clock().call_later( - 1000, - run_as_background_process, - "_count_known_servers", - self._count_known_servers, + 1000, self._count_known_servers, ) LaterGauge( "synapse_federation_known_servers", @@ -89,6 +82,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): lambda: self._known_servers_count, ) + @wrap_as_background_process("_count_known_servers") async def _count_known_servers(self): """ Count the servers that this server knows about. @@ -356,6 +350,38 @@ class RoomMemberWorkerStore(EventsWorkerStore): return results + async def get_local_current_membership_for_user_in_room( + self, user_id: str, room_id: str + ) -> Tuple[Optional[str], Optional[str]]: + """Retrieve the current local membership state and event ID for a user in a room. + + Args: + user_id: The ID of the user. + room_id: The ID of the room. + + Returns: + A tuple of (membership_type, event_id). Both will be None if a + room_id/user_id pair is not found. + """ + # Paranoia check. 
+ if not self.hs.is_mine_id(user_id): + raise Exception( + "Cannot call 'get_local_current_membership_for_user_in_room' on " + "non-local user %s" % (user_id,), + ) + + results_dict = await self.db_pool.simple_select_one( + "local_current_membership", + {"room_id": room_id, "user_id": user_id}, + ("membership", "event_id"), + allow_none=True, + desc="get_local_current_membership_for_user_in_room", + ) + if not results_dict: + return None, None + + return results_dict.get("membership"), results_dict.get("event_id") + @cached(max_entries=500000, iterable=True) async def get_rooms_for_user_with_stream_ordering( self, user_id: str @@ -535,7 +561,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # If we do then we can reuse that result and simply update it with # any membership changes in `delta_ids` if context.prev_group and context.delta_ids: - prev_res = self._get_joined_users_from_context.cache.get( + prev_res = self._get_joined_users_from_context.cache.get_immediate( (room_id, context.prev_group), None ) if prev_res and isinstance(prev_res, dict): diff --git a/synapse/storage/databases/main/schema/delta/20/pushers.py b/synapse/storage/databases/main/schema/delta/20/pushers.py
index 3edfcfd783..45b846e6a7 100644 --- a/synapse/storage/databases/main/schema/delta/20/pushers.py +++ b/synapse/storage/databases/main/schema/delta/20/pushers.py
@@ -66,16 +66,15 @@ def run_create(cur, database_engine, *args, **kwargs): row[8] = bytes(row[8]).decode("utf-8") row[11] = bytes(row[11]).decode("utf-8") cur.execute( - database_engine.convert_param_style( - """ - INSERT into pushers2 ( - id, user_name, access_token, profile_tag, kind, - app_id, app_display_name, device_display_name, - pushkey, ts, lang, data, last_token, last_success, - failing_since - ) values (%s)""" - % (",".join(["?" for _ in range(len(row))])) - ), + """ + INSERT into pushers2 ( + id, user_name, access_token, profile_tag, kind, + app_id, app_display_name, device_display_name, + pushkey, ts, lang, data, last_token, last_success, + failing_since + ) values (%s) + """ + % (",".join(["?" for _ in range(len(row))])), row, ) count += 1 diff --git a/synapse/storage/databases/main/schema/delta/25/fts.py b/synapse/storage/databases/main/schema/delta/25/fts.py
index ee675e71ff..21f57825d4 100644 --- a/synapse/storage/databases/main/schema/delta/25/fts.py +++ b/synapse/storage/databases/main/schema/delta/25/fts.py
@@ -71,8 +71,6 @@ def run_create(cur, database_engine, *args, **kwargs): " VALUES (?, ?)" ) - sql = database_engine.convert_param_style(sql) - cur.execute(sql, ("event_search", progress_json)) diff --git a/synapse/storage/databases/main/schema/delta/27/ts.py b/synapse/storage/databases/main/schema/delta/27/ts.py
index b7972cfa8e..1c6058063f 100644 --- a/synapse/storage/databases/main/schema/delta/27/ts.py +++ b/synapse/storage/databases/main/schema/delta/27/ts.py
@@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs): " VALUES (?, ?)" ) - sql = database_engine.convert_param_style(sql) - cur.execute(sql, ("event_origin_server_ts", progress_json)) diff --git a/synapse/storage/databases/main/schema/delta/30/as_users.py b/synapse/storage/databases/main/schema/delta/30/as_users.py
index b42c02710a..7f08fabe9f 100644 --- a/synapse/storage/databases/main/schema/delta/30/as_users.py +++ b/synapse/storage/databases/main/schema/delta/30/as_users.py
@@ -59,9 +59,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): user_chunks = (user_ids[i : i + 100] for i in range(0, len(user_ids), n)) for chunk in user_chunks: cur.execute( - database_engine.convert_param_style( - "UPDATE users SET appservice_id = ? WHERE name IN (%s)" - % (",".join("?" for _ in chunk),) - ), + "UPDATE users SET appservice_id = ? WHERE name IN (%s)" + % (",".join("?" for _ in chunk),), [as_id] + chunk, ) diff --git a/synapse/storage/databases/main/schema/delta/31/pushers.py b/synapse/storage/databases/main/schema/delta/31/pushers.py
index 9bb504aad5..5be81c806a 100644 --- a/synapse/storage/databases/main/schema/delta/31/pushers.py +++ b/synapse/storage/databases/main/schema/delta/31/pushers.py
@@ -65,16 +65,15 @@ def run_create(cur, database_engine, *args, **kwargs): row = list(row) row[12] = token_to_stream_ordering(row[12]) cur.execute( - database_engine.convert_param_style( - """ - INSERT into pushers2 ( - id, user_name, access_token, profile_tag, kind, - app_id, app_display_name, device_display_name, - pushkey, ts, lang, data, last_stream_ordering, last_success, - failing_since - ) values (%s)""" - % (",".join(["?" for _ in range(len(row))])) - ), + """ + INSERT into pushers2 ( + id, user_name, access_token, profile_tag, kind, + app_id, app_display_name, device_display_name, + pushkey, ts, lang, data, last_stream_ordering, last_success, + failing_since + ) values (%s) + """ + % (",".join(["?" for _ in range(len(row))])), row, ) count += 1 diff --git a/synapse/storage/databases/main/schema/delta/31/search_update.py b/synapse/storage/databases/main/schema/delta/31/search_update.py
index 63b757ade6..b84c844e3a 100644 --- a/synapse/storage/databases/main/schema/delta/31/search_update.py +++ b/synapse/storage/databases/main/schema/delta/31/search_update.py
@@ -55,8 +55,6 @@ def run_create(cur, database_engine, *args, **kwargs): " VALUES (?, ?)" ) - sql = database_engine.convert_param_style(sql) - cur.execute(sql, ("event_search_order", progress_json)) diff --git a/synapse/storage/databases/main/schema/delta/33/event_fields.py b/synapse/storage/databases/main/schema/delta/33/event_fields.py
index a3e81eeac7..e928c66a8f 100644 --- a/synapse/storage/databases/main/schema/delta/33/event_fields.py +++ b/synapse/storage/databases/main/schema/delta/33/event_fields.py
@@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs): " VALUES (?, ?)" ) - sql = database_engine.convert_param_style(sql) - cur.execute(sql, ("event_fields_sender_url", progress_json)) diff --git a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
index a26057dfb6..ad875c733a 100644 --- a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py +++ b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
@@ -23,8 +23,5 @@ def run_create(cur, database_engine, *args, **kwargs): def run_upgrade(cur, database_engine, *args, **kwargs): cur.execute( - database_engine.convert_param_style( - "UPDATE remote_media_cache SET last_access_ts = ?" - ), - (int(time.time() * 1000),), + "UPDATE remote_media_cache SET last_access_ts = ?", (int(time.time() * 1000),), ) diff --git a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
index 1de8b54961..bb7296852a 100644 --- a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py +++ b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
@@ -1,6 +1,8 @@ import logging +from io import StringIO from synapse.storage.engines import PostgresEngine +from synapse.storage.prepare_database import execute_statements_from_stream logger = logging.getLogger(__name__) @@ -46,7 +48,4 @@ def run_create(cur, database_engine, *args, **kwargs): select_clause, ) - if isinstance(database_engine, PostgresEngine): - cur.execute(sql) - else: - cur.executescript(sql) + execute_statements_from_stream(cur, StringIO(sql)) diff --git a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
index 63b5acdcf7..44917f0a2e 100644 --- a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py +++ b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
@@ -68,7 +68,6 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): INNER JOIN room_memberships AS r USING (event_id) WHERE type = 'm.room.member' AND state_key LIKE ? """ - sql = database_engine.convert_param_style(sql) cur.execute(sql, ("%:" + config.server_name,)) cur.execute( diff --git a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
index b64926e9c9..3275ae2b20 100644 --- a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres +++ b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
@@ -20,14 +20,14 @@ */ -- add new index that includes method to local media -INSERT INTO background_updates (update_name, progress_json) VALUES - ('local_media_repository_thumbnails_method_idx', '{}'); +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (5807, 'local_media_repository_thumbnails_method_idx', '{}'); -- add new index that includes method to remote media -INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx'); +INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + (5807, 'remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx'); -- drop old index -INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx'); +INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + (5807, 'media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx'); diff --git a/synapse/storage/databases/main/schema/delta/58/11dehydration.sql b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql new file mode 100644
index 0000000000..7851a0a825 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql
@@ -0,0 +1,20 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS dehydrated_devices( + user_id TEXT NOT NULL PRIMARY KEY, + device_id TEXT NOT NULL, + device_data TEXT NOT NULL -- JSON-encoded client-defined data +); diff --git a/synapse/storage/databases/main/schema/delta/58/11fallback.sql b/synapse/storage/databases/main/schema/delta/58/11fallback.sql new file mode 100644
index 0000000000..4ed981dbf8 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/11fallback.sql
@@ -0,0 +1,24 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS e2e_fallback_keys_json ( + user_id TEXT NOT NULL, -- The user this fallback key is for. + device_id TEXT NOT NULL, -- The device this fallback key is for. + algorithm TEXT NOT NULL, -- Which algorithm this fallback key is for. + key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads. + key_json TEXT NOT NULL, -- The key as a JSON blob. + used BOOLEAN NOT NULL DEFAULT FALSE, -- Whether the key has been used or not. + CONSTRAINT e2e_fallback_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm) +); diff --git a/synapse/storage/databases/main/schema/delta/58/12room_stats.sql b/synapse/storage/databases/main/schema/delta/58/12room_stats.sql
index cade5dcca8..fd733adf13 100644 --- a/synapse/storage/databases/main/schema/delta/58/12room_stats.sql +++ b/synapse/storage/databases/main/schema/delta/58/12room_stats.sql
@@ -28,5 +28,5 @@ -- functionality as the old one. This effectively restarts the background job -- from the beginning, without running it twice in a row, supporting both -- upgrade usecases. -INSERT INTO background_updates (update_name, progress_json) VALUES - ('populate_stats_process_rooms_2', '{}'); +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (5812, 'populate_stats_process_rooms_2', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres b/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres new file mode 100644
index 0000000000..841186b826 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres
@@ -0,0 +1,25 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- A unique and immutable mapping between instance name and an integer ID. This +-- lets us refer to instances via a small ID in e.g. stream tokens, without +-- having to encode the full name. +CREATE TABLE IF NOT EXISTS instance_map ( + instance_id SERIAL PRIMARY KEY, + instance_name TEXT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS instance_map_idx ON instance_map(instance_name); diff --git a/synapse/storage/databases/main/schema/delta/58/19txn_id.sql b/synapse/storage/databases/main/schema/delta/58/19txn_id.sql new file mode 100644
index 0000000000..b2454121a8 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/19txn_id.sql
@@ -0,0 +1,40 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- A map of recent events persisted with transaction IDs. Used to deduplicate +-- send event requests with the same transaction ID. +-- +-- Note: transaction IDs are scoped to the room ID/user ID/access token that was +-- used to make the request. +-- +-- Note: The foreign key constraints are ON DELETE CASCADE, as if we delete the +-- events or access token we don't want to try and de-duplicate the event. +CREATE TABLE IF NOT EXISTS event_txn_id ( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + user_id TEXT NOT NULL, + token_id BIGINT NOT NULL, + txn_id TEXT NOT NULL, + inserted_ts BIGINT NOT NULL, + FOREIGN KEY (event_id) + REFERENCES events (event_id) ON DELETE CASCADE, + FOREIGN KEY (token_id) + REFERENCES access_tokens (id) ON DELETE CASCADE +); + +CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_event_id ON event_txn_id(event_id); +CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_txn_id ON event_txn_id(room_id, user_id, token_id, txn_id); +CREATE INDEX IF NOT EXISTS event_txn_id_ts ON event_txn_id(inserted_ts); diff --git a/synapse/storage/databases/main/schema/delta/58/20instance_name_event_tables.sql b/synapse/storage/databases/main/schema/delta/58/20instance_name_event_tables.sql new file mode 100644
index 0000000000..ad1f481428 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/20instance_name_event_tables.sql
@@ -0,0 +1,17 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE current_state_delta_stream ADD COLUMN instance_name TEXT; +ALTER TABLE ex_outlier_stream ADD COLUMN instance_name TEXT; diff --git a/synapse/storage/databases/main/schema/delta/58/20user_daily_visits.sql b/synapse/storage/databases/main/schema/delta/58/20user_daily_visits.sql new file mode 100644
index 0000000000..b0b5dcddce --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/20user_daily_visits.sql
@@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + -- Add new column to user_daily_visits to track user agent +ALTER TABLE user_daily_visits + ADD COLUMN user_agent TEXT; diff --git a/synapse/storage/databases/main/schema/delta/58/21as_device_stream.sql b/synapse/storage/databases/main/schema/delta/58/21as_device_stream.sql new file mode 100644
index 0000000000..7b84a207fd --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/21as_device_stream.sql
@@ -0,0 +1,17 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE application_services_state ADD COLUMN read_receipt_stream_id INT; +ALTER TABLE application_services_state ADD COLUMN presence_stream_id INT; \ No newline at end of file diff --git a/synapse/storage/databases/main/schema/delta/58/21drop_device_max_stream_id.sql b/synapse/storage/databases/main/schema/delta/58/21drop_device_max_stream_id.sql new file mode 100644
index 0000000000..01ea6eddcf --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/21drop_device_max_stream_id.sql
@@ -0,0 +1 @@ +DROP TABLE device_max_stream_id; diff --git a/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql b/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql new file mode 100644
index 0000000000..00a9431a97 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql
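This delta adds the puppets_user_id column that the registration-store changes earlier in this diff join against (COALESCE(puppets_user_id, access_tokens.user_id)), so a token can act as another user while the token's real owner is still reported separately as token_owner. A tiny sketch of that resolution rule; the helper is hypothetical and only mirrors the COALESCE semantics:

def resolve_effective_user(token_row: dict) -> str:
    """Mirrors COALESCE(puppets_user_id, user_id): requests act as the
    puppeted user when the column is set, otherwise as the token's owner."""
    return token_row.get("puppets_user_id") or token_row["user_id"]


# Example: an admin token puppeting @alice:example.com
print(resolve_effective_user({"user_id": "@admin:example.com",
                              "puppets_user_id": "@alice:example.com"}))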
@@ -0,0 +1,17 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Whether the access token is an admin token for controlling another user. +ALTER TABLE access_tokens ADD COLUMN puppets_user_id TEXT; diff --git a/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql b/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql new file mode 100644
index 0000000000..e1a35be831 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql
@@ -0,0 +1,2 @@ +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (5822, 'users_have_local_media', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql b/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql new file mode 100644
index 0000000000..75c3915a94 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql
@@ -0,0 +1,17 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (5823, 'e2e_cross_signing_keys_idx', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql b/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql new file mode 100644
index 0000000000..8a39d54aed --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql
@@ -0,0 +1,19 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- this index is essentially redundant. The only time it was ever used was when purging +-- rooms - and Synapse 1.24 will change that. + +DROP INDEX IF EXISTS event_json_room_id; diff --git a/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql b/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql new file mode 100644
index 0000000000..8f5e65aa71 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql
@@ -0,0 +1,17 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (5825, 'user_external_ids_user_id_idx', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql b/synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql new file mode 100644
index 0000000000..1a101cd5eb --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql
@@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- The last time this access token was "validated" (i.e. logged in or succeeded +-- at user-interactive authentication). +ALTER TABLE access_tokens ADD COLUMN last_validated BIGINT; diff --git a/synapse/storage/databases/main/schema/delta/58/27local_invites.sql b/synapse/storage/databases/main/schema/delta/58/27local_invites.sql new file mode 100644
index 0000000000..44b2a0572f --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/27local_invites.sql
@@ -0,0 +1,18 @@ +/* + * Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This is unused since Synapse v1.17.0. +DROP TABLE local_invites; diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 5beb302be3..0cdb3ec1f7 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py
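The stats.py hunks below add the UserSortOrder enum and get_users_media_usage_paginate, which pages through local users by the size and count of their uploaded media. A hedged usage sketch; the surrounding handler is invented, but the argument names and the (users, count) return shape follow the new method's signature:

async def list_heaviest_uploaders(stats_store, page: int = 0, page_size: int = 25):
    """Illustrative caller for the new media-usage pagination query below:
    returns one page of users ordered by total uploaded media size, largest
    first, plus the total number of matching users."""
    users, total = await stats_store.get_users_media_usage_paginate(
        start=page * page_size,
        limit=page_size,
        order_by="media_length",   # UserSortOrder.MEDIA_LENGTH.value
        direction="b",             # "b" sorts descending in the query builder
    )
    return users, total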
@@ -16,15 +16,18 @@ import logging from collections import Counter +from enum import Enum from itertools import chain from typing import Any, Dict, List, Optional, Tuple from twisted.internet.defer import DeferredLock from synapse.api.constants import EventTypes, Membership +from synapse.api.errors import StoreError from synapse.storage.database import DatabasePool from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.storage.engines import PostgresEngine +from synapse.types import JsonDict from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -59,6 +62,23 @@ TYPE_TO_TABLE = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")} +class UserSortOrder(Enum): + """ + Enum to define the sorting method used when returning users + with get_users_media_usage_paginate + + MEDIA_LENGTH = ordered by size of uploaded media. Smallest to largest. + MEDIA_COUNT = ordered by number of uploaded media. Smallest to largest. + USER_ID = ordered alphabetically by `user_id`. + DISPLAYNAME = ordered alphabetically by `displayname` + """ + + MEDIA_LENGTH = "media_length" + MEDIA_COUNT = "media_count" + USER_ID = "user_id" + DISPLAYNAME = "displayname" + + class StatsStore(StateDeltasStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) @@ -882,3 +902,110 @@ class StatsStore(StateDeltasStore): complete_with_stream_id=pos, absolute_field_overrides={"joined_rooms": joined_rooms}, ) + + async def get_users_media_usage_paginate( + self, + start: int, + limit: int, + from_ts: Optional[int] = None, + until_ts: Optional[int] = None, + order_by: Optional[UserSortOrder] = UserSortOrder.USER_ID.value, + direction: Optional[str] = "f", + search_term: Optional[str] = None, + ) -> Tuple[List[JsonDict], Dict[str, int]]: + """Function to retrieve a paginated list of users and their uploaded local media + (size and number). This will return a json list of users and the + total number of users matching the filter criteria. + + Args: + start: offset to begin the query from + limit: number of rows to retrieve + from_ts: request only media that are created later than this timestamp (ms) + until_ts: request only media that are created earlier than this timestamp (ms) + order_by: the sort order of the returned list + direction: sort ascending or descending + search_term: a string to filter user names by + Returns: + A list of user dicts and an integer representing the total number of + users that exist given this query + """ + + def get_users_media_usage_paginate_txn(txn): + filters = [] + args = [self.hs.config.server_name] + + if search_term: + filters.append("(lmr.user_id LIKE ? 
OR displayname LIKE ?)") + args.extend(["@%" + search_term + "%:%", "%" + search_term + "%"]) + + if from_ts: + filters.append("created_ts >= ?") + args.extend([from_ts]) + if until_ts: + filters.append("created_ts <= ?") + args.extend([until_ts]) + + # Set ordering + if UserSortOrder(order_by) == UserSortOrder.MEDIA_LENGTH: + order_by_column = "media_length" + elif UserSortOrder(order_by) == UserSortOrder.MEDIA_COUNT: + order_by_column = "media_count" + elif UserSortOrder(order_by) == UserSortOrder.USER_ID: + order_by_column = "lmr.user_id" + elif UserSortOrder(order_by) == UserSortOrder.DISPLAYNAME: + order_by_column = "displayname" + else: + raise StoreError( + 500, "Incorrect value for order_by provided: %s" % order_by + ) + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" + + sql_base = """ + FROM local_media_repository as lmr + LEFT JOIN profiles AS p ON lmr.user_id = '@' || p.user_id || ':' || ? + {} + GROUP BY lmr.user_id, displayname + """.format( + where_clause + ) + + # SQLite does not support SELECT COUNT(*) OVER() + sql = """ + SELECT COUNT(*) FROM ( + SELECT lmr.user_id + {sql_base} + ) AS count_user_ids + """.format( + sql_base=sql_base, + ) + txn.execute(sql, args) + count = txn.fetchone()[0] + + sql = """ + SELECT + lmr.user_id, + displayname, + COUNT(lmr.user_id) as media_count, + SUM(media_length) as media_length + {sql_base} + ORDER BY {order_by_column} {order} + LIMIT ? OFFSET ? + """.format( + sql_base=sql_base, order_by_column=order_by_column, order=order, + ) + + args += [limit, start] + txn.execute(sql, args) + users = self.db_pool.cursor_to_dict(txn) + + return users, count + + return await self.db_pool.runInteraction( + "get_users_media_usage_paginate_txn", get_users_media_usage_paginate_txn + ) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 37249f1e3f..e3b9ff5ca6 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py
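The stream.py hunks below introduce _filter_results, which lets queries over-fetch rows between the minimum and maximum stream positions implied by multi-writer tokens and then drop anything outside the real bounds in Python, rather than encoding the per-instance position map in SQL. A standalone sketch of the historical (topological, stream) comparison that function performs; the live-token/instance-map branch is deliberately omitted here:

from typing import Optional, Tuple

def in_historical_range(
    lower: Optional[Tuple[int, int]],
    upper: Optional[Tuple[int, int]],
    topological_ordering: int,
    stream_ordering: int,
) -> bool:
    """Simplified version of the tuple comparison in _filter_results below:
    keep events strictly after `lower` and at or before `upper`, where each
    bound is a (topological, stream) pair or None for unbounded."""
    event = (topological_ordering, stream_ordering)
    if lower is not None and event <= lower:
        return False
    if upper is not None and upper < event:
        return False
    return True


# Example: event (5, 120) lies between tokens (5, 100) and (6, 10)
print(in_historical_range((5, 100), (6, 10), 5, 120))  # True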
@@ -53,7 +53,9 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import Collection, PersistedEventPosition, RoomStreamToken +from synapse.util.caches.descriptors import cached from synapse.util.caches.stream_change_cache import StreamChangeCache if TYPE_CHECKING: @@ -208,6 +210,55 @@ def _make_generic_sql_bound( ) +def _filter_results( + lower_token: Optional[RoomStreamToken], + upper_token: Optional[RoomStreamToken], + instance_name: str, + topological_ordering: int, + stream_ordering: int, +) -> bool: + """Returns True if the event persisted by the given instance at the given + topological/stream_ordering falls between the two tokens (taking a None + token to mean unbounded). + + Used to filter results from fetching events in the DB against the given + tokens. This is necessary to handle the case where the tokens include + position maps, which we handle by fetching more than necessary from the DB + and then filtering (rather than attempting to construct a complicated SQL + query). + """ + + event_historical_tuple = ( + topological_ordering, + stream_ordering, + ) + + if lower_token: + if lower_token.topological is not None: + # If these are historical tokens we compare the `(topological, stream)` + # tuples. + if event_historical_tuple <= lower_token.as_historical_tuple(): + return False + + else: + # If these are live tokens we compare the stream ordering against the + # writers stream position. + if stream_ordering <= lower_token.get_stream_pos_for_instance( + instance_name + ): + return False + + if upper_token: + if upper_token.topological is not None: + if upper_token.as_historical_tuple() < event_historical_tuple: + return False + else: + if upper_token.get_stream_pos_for_instance(instance_name) < stream_ordering: + return False + + return True + + def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: # NB: This may create SQL clauses that don't optimise well (and we don't # have indices on all possible clauses). E.g. it may create @@ -305,7 +356,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): raise NotImplementedError() def get_room_max_token(self) -> RoomStreamToken: - return RoomStreamToken(None, self.get_room_max_stream_ordering()) + """Get a `RoomStreamToken` that marks the current maximum persisted + position of the events stream. Useful to get a token that represents + "now". + + The token returned is a "live" token that may have an instance_map + component. + """ + + min_pos = self._stream_id_gen.get_current_token() + + positions = {} + if isinstance(self._stream_id_gen, MultiWriterIdGenerator): + # The `min_pos` is the minimum position that we know all instances + # have finished persisting to, so we only care about instances whose + # positions are ahead of that. (Instance positions can be behind the + # min position as there are times we can work out that the minimum + # position is ahead of the naive minimum across all current + # positions. 
See MultiWriterIdGenerator for details) + positions = { + i: p + for i, p in self._stream_id_gen.get_positions().items() + if p > min_pos + } + + return RoomStreamToken(None, min_pos, positions) async def get_room_events_stream_for_rooms( self, @@ -404,25 +479,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): if from_key == to_key: return [], from_key - from_id = from_key.stream - to_id = to_key.stream - - has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id) + has_changed = self._events_stream_cache.has_entity_changed( + room_id, from_key.stream + ) if not has_changed: return [], from_key def f(txn): - sql = ( - "SELECT event_id, stream_ordering FROM events WHERE" - " room_id = ?" - " AND not outlier" - " AND stream_ordering > ? AND stream_ordering <= ?" - " ORDER BY stream_ordering %s LIMIT ?" - ) % (order,) - txn.execute(sql, (room_id, from_id, to_id, limit)) - - rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] + # To handle tokens with a non-empty instance_map we fetch more + # results than necessary and then filter down + min_from_id = from_key.stream + max_to_id = to_key.get_max_stream_pos() + + sql = """ + SELECT event_id, instance_name, topological_ordering, stream_ordering + FROM events + WHERE + room_id = ? + AND not outlier + AND stream_ordering > ? AND stream_ordering <= ? + ORDER BY stream_ordering %s LIMIT ? + """ % ( + order, + ) + txn.execute(sql, (room_id, min_from_id, max_to_id, 2 * limit)) + + rows = [ + _EventDictReturn(event_id, None, stream_ordering) + for event_id, instance_name, topological_ordering, stream_ordering in txn + if _filter_results( + from_key, + to_key, + instance_name, + topological_ordering, + stream_ordering, + ) + ][:limit] return rows rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f) @@ -431,7 +524,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): [r.event_id for r in rows], get_prev_content=True ) - self._set_before_and_after(ret, rows, topo_order=from_id is None) + self._set_before_and_after(ret, rows, topo_order=False) if order.lower() == "desc": ret.reverse() @@ -448,31 +541,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): async def get_membership_changes_for_user( self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken ) -> List[EventBase]: - from_id = from_key.stream - to_id = to_key.stream - if from_key == to_key: return [] - if from_id: + if from_key: has_changed = self._membership_stream_cache.has_entity_changed( - user_id, int(from_id) + user_id, int(from_key.stream) ) if not has_changed: return [] def f(txn): - sql = ( - "SELECT m.event_id, stream_ordering FROM events AS e," - " room_memberships AS m" - " WHERE e.event_id = m.event_id" - " AND m.user_id = ?" - " AND e.stream_ordering > ? AND e.stream_ordering <= ?" - " ORDER BY e.stream_ordering ASC" - ) - txn.execute(sql, (user_id, from_id, to_id)) - - rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] + # To handle tokens with a non-empty instance_map we fetch more + # results than necessary and then filter down + min_from_id = from_key.stream + max_to_id = to_key.get_max_stream_pos() + + sql = """ + SELECT m.event_id, instance_name, topological_ordering, stream_ordering + FROM events AS e, room_memberships AS m + WHERE e.event_id = m.event_id + AND m.user_id = ? + AND e.stream_ordering > ? AND e.stream_ordering <= ? 
+ ORDER BY e.stream_ordering ASC + """ + txn.execute(sql, (user_id, min_from_id, max_to_id,)) + + rows = [ + _EventDictReturn(event_id, None, stream_ordering) + for event_id, instance_name, topological_ordering, stream_ordering in txn + if _filter_results( + from_key, + to_key, + instance_name, + topological_ordering, + stream_ordering, + ) + ] return rows @@ -546,7 +651,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): async def get_room_event_before_stream_ordering( self, room_id: str, stream_ordering: int - ) -> Tuple[int, int, str]: + ) -> Optional[Tuple[int, int, str]]: """Gets details of the first event in a room at or before a stream ordering Args: @@ -589,19 +694,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): ) return "t%d-%d" % (topo, token) - async def get_stream_id_for_event(self, event_id: str) -> int: - """The stream ID for an event - Args: - event_id: The id of the event to look up a stream token for. - Raises: - StoreError if the event wasn't in the database. - Returns: - A stream ID. - """ - return await self.db_pool.runInteraction( - "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id, - ) - def get_stream_id_for_event_txn( self, txn: LoggingTransaction, event_id: str, allow_none=False, ) -> int: @@ -979,11 +1071,46 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): else: order = "ASC" + # The bounds for the stream tokens are complicated by the fact + # that we need to handle the instance_map part of the tokens. We do this + # by fetching all events between the min stream token and the maximum + # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and + # then filtering the results. + if from_token.topological is not None: + from_bound = ( + from_token.as_historical_tuple() + ) # type: Tuple[Optional[int], int] + elif direction == "b": + from_bound = ( + None, + from_token.get_max_stream_pos(), + ) + else: + from_bound = ( + None, + from_token.stream, + ) + + to_bound = None # type: Optional[Tuple[Optional[int], int]] + if to_token: + if to_token.topological is not None: + to_bound = to_token.as_historical_tuple() + elif direction == "b": + to_bound = ( + None, + to_token.stream, + ) + else: + to_bound = ( + None, + to_token.get_max_stream_pos(), + ) + bounds = generate_pagination_where_clause( direction=direction, column_names=("topological_ordering", "stream_ordering"), - from_token=from_token.as_tuple(), - to_token=to_token.as_tuple() if to_token else None, + from_token=from_bound, + to_token=to_bound, engine=self.database_engine, ) @@ -993,7 +1120,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): bounds += " AND " + filter_clause args.extend(filter_args) - args.append(int(limit)) + # We fetch more events as we'll filter the result set + args.append(int(limit) * 2) select_keywords = "SELECT" join_clause = "" @@ -1015,7 +1143,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): select_keywords += "DISTINCT" sql = """ - %(select_keywords)s event_id, topological_ordering, stream_ordering + %(select_keywords)s + event_id, instance_name, + topological_ordering, stream_ordering FROM events %(join_clause)s WHERE outlier = ? AND room_id = ? 
AND %(bounds)s @@ -1030,7 +1160,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): txn.execute(sql, args) - rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn] + # Filter the result set. + rows = [ + _EventDictReturn(event_id, topological_ordering, stream_ordering) + for event_id, instance_name, topological_ordering, stream_ordering in txn + if _filter_results( + lower_token=to_token if direction == "b" else from_token, + upper_token=from_token if direction == "b" else to_token, + instance_name=instance_name, + topological_ordering=topological_ordering, + stream_ordering=stream_ordering, + ) + ][:limit] if rows: topo = rows[-1].topological_ordering @@ -1095,6 +1236,58 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): return (events, token) + @cached() + async def get_id_for_instance(self, instance_name: str) -> int: + """Get a unique, immutable ID that corresponds to the given Synapse worker instance. + """ + + def _get_id_for_instance_txn(txn): + instance_id = self.db_pool.simple_select_one_onecol_txn( + txn, + table="instance_map", + keyvalues={"instance_name": instance_name}, + retcol="instance_id", + allow_none=True, + ) + if instance_id is not None: + return instance_id + + # If we don't have an entry upsert one. + # + # We could do this before the first check, and rely on the cache for + # efficiency, but each UPSERT causes the next ID to increment which + # can quickly bloat the size of the generated IDs for new instances. + self.db_pool.simple_upsert_txn( + txn, + table="instance_map", + keyvalues={"instance_name": instance_name}, + values={}, + ) + + return self.db_pool.simple_select_one_onecol_txn( + txn, + table="instance_map", + keyvalues={"instance_name": instance_name}, + retcol="instance_id", + ) + + return await self.db_pool.runInteraction( + "get_id_for_instance", _get_id_for_instance_txn + ) + + @cached() + async def get_name_from_instance_id(self, instance_id: int) -> str: + """Get the instance name from an ID previously returned by + `get_id_for_instance`. + """ + + return await self.db_pool.simple_select_one_onecol( + table="instance_map", + keyvalues={"instance_id": instance_id}, + retcol="instance_name", + desc="get_name_from_instance_id", + ) + class StreamStore(StreamWorkerStore): def get_room_max_stream_ordering(self) -> int: diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
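The stream-store changes above fetch up to twice the requested number of rows and then filter them in Python against the token bounds, because a multi-writer token carries a per-instance position map that is awkward to express in SQL. A rough standalone sketch of that filtering predicate; the Token class here is a simplified stand-in for Synapse's RoomStreamToken:

from dataclasses import dataclass, field
from typing import Dict, Optional, Tuple

@dataclass
class Token:
    topological: Optional[int]          # set for "historical" (pagination) tokens
    stream: int                         # minimum stream position across all writers
    instance_map: Dict[str, int] = field(default_factory=dict)

    def as_historical_tuple(self) -> Tuple[int, int]:
        assert self.topological is not None
        return (self.topological, self.stream)

    def pos_for(self, instance: str) -> int:
        # Fall back to the global minimum if we have no entry for this writer.
        return self.instance_map.get(instance, self.stream)

def event_between(
    lower: Optional[Token], upper: Optional[Token],
    instance: str, topo: int, stream: int,
) -> bool:
    """True if an event persisted by `instance` falls between the two tokens."""
    if lower:
        if lower.topological is not None:
            # Historical tokens compare on the (topological, stream) tuple.
            if (topo, stream) <= lower.as_historical_tuple():
                return False
        elif stream <= lower.pos_for(instance):
            # Live tokens compare against that writer's stream position.
            return False
    if upper:
        if upper.topological is not None:
            if upper.as_historical_tuple() < (topo, stream):
                return False
        elif upper.pos_for(instance) < stream:
            return False
    return True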
index 97aed1500e..59207cadd4 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -19,7 +19,7 @@ from typing import Iterable, List, Optional, Tuple from canonicaljson import encode_canonical_json -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.engines import PostgresEngine, Sqlite3Engine @@ -43,15 +43,33 @@ _UpdateTransactionRow = namedtuple( SENTINEL = object() -class TransactionStore(SQLBaseStore): +class TransactionWorkerStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super().__init__(database, db_conn, hs) + + if hs.config.run_background_tasks: + self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000) + + @wrap_as_background_process("cleanup_transactions") + async def _cleanup_transactions(self) -> None: + now = self._clock.time_msec() + month_ago = now - 30 * 24 * 60 * 60 * 1000 + + def _cleanup_transactions_txn(txn): + txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) + + await self.db_pool.runInteraction( + "_cleanup_transactions", _cleanup_transactions_txn + ) + + +class TransactionStore(TransactionWorkerStore): """A collection of queries for handling PDUs. """ def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000) - self._destination_retry_cache = ExpiringCache( cache_name="get_destination_retry_timings", clock=self._clock, @@ -190,42 +208,56 @@ class TransactionStore(SQLBaseStore): """ self._destination_retry_cache.pop(destination, None) - return await self.db_pool.runInteraction( - "set_destination_retry_timings", - self._set_destination_retry_timings, - destination, - failure_ts, - retry_last_ts, - retry_interval, - ) + if self.database_engine.can_native_upsert: + return await self.db_pool.runInteraction( + "set_destination_retry_timings", + self._set_destination_retry_timings_native, + destination, + failure_ts, + retry_last_ts, + retry_interval, + db_autocommit=True, # Safe as its a single upsert + ) + else: + return await self.db_pool.runInteraction( + "set_destination_retry_timings", + self._set_destination_retry_timings_emulated, + destination, + failure_ts, + retry_last_ts, + retry_interval, + ) - def _set_destination_retry_timings( + def _set_destination_retry_timings_native( self, txn, destination, failure_ts, retry_last_ts, retry_interval ): + assert self.database_engine.can_native_upsert + + # Upsert retry time interval if retry_interval is zero (i.e. we're + # resetting it) or greater than the existing retry interval. + # + # WARNING: This is executed in autocommit, so we shouldn't add any more + # SQL calls in here (without being very careful). + sql = """ + INSERT INTO destinations ( + destination, failure_ts, retry_last_ts, retry_interval + ) + VALUES (?, ?, ?, ?) + ON CONFLICT (destination) DO UPDATE SET + failure_ts = EXCLUDED.failure_ts, + retry_last_ts = EXCLUDED.retry_last_ts, + retry_interval = EXCLUDED.retry_interval + WHERE + EXCLUDED.retry_interval = 0 + OR destinations.retry_interval IS NULL + OR destinations.retry_interval < EXCLUDED.retry_interval + """ - if self.database_engine.can_native_upsert: - # Upsert retry time interval if retry_interval is zero (i.e. we're - # resetting it) or greater than the existing retry interval. 
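The native-upsert path above folds the "only reset the interval to zero or increase it" rule into the ON CONFLICT clause itself, which is why it is safe to run as a single autocommit statement. A cut-down sketch of the same idea against any DB-API cursor; the SQL follows the diff, the paramstyle shown is qmark and would need adjusting for drivers such as psycopg2, and native upsert support (Postgres, or SQLite 3.24+) is assumed:

UPSERT_RETRY_SQL = """
    INSERT INTO destinations (destination, failure_ts, retry_last_ts, retry_interval)
    VALUES (?, ?, ?, ?)
    ON CONFLICT (destination) DO UPDATE SET
        failure_ts = EXCLUDED.failure_ts,
        retry_last_ts = EXCLUDED.retry_last_ts,
        retry_interval = EXCLUDED.retry_interval
    WHERE
        EXCLUDED.retry_interval = 0
        OR destinations.retry_interval IS NULL
        OR destinations.retry_interval < EXCLUDED.retry_interval
"""

def set_retry_timings(cur, destination, failure_ts, retry_last_ts, retry_interval):
    # One statement: insert the row, or update it only when we are resetting
    # the interval or backing off further than the value already stored.
    cur.execute(UPSERT_RETRY_SQL, (destination, failure_ts, retry_last_ts, retry_interval))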
- - sql = """ - INSERT INTO destinations ( - destination, failure_ts, retry_last_ts, retry_interval - ) - VALUES (?, ?, ?, ?) - ON CONFLICT (destination) DO UPDATE SET - failure_ts = EXCLUDED.failure_ts, - retry_last_ts = EXCLUDED.retry_last_ts, - retry_interval = EXCLUDED.retry_interval - WHERE - EXCLUDED.retry_interval = 0 - OR destinations.retry_interval IS NULL - OR destinations.retry_interval < EXCLUDED.retry_interval - """ - - txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval)) - - return + txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval)) + def _set_destination_retry_timings_emulated( + self, txn, destination, failure_ts, retry_last_ts, retry_interval + ): self.database_engine.lock_table(txn, "destinations") # We need to be careful here as the data may have changed from under us @@ -266,22 +298,6 @@ class TransactionStore(SQLBaseStore): }, ) - def _start_cleanup_transactions(self): - return run_as_background_process( - "cleanup_transactions", self._cleanup_transactions - ) - - async def _cleanup_transactions(self) -> None: - now = self._clock.time_msec() - month_ago = now - 30 * 24 * 60 * 60 * 1000 - - def _cleanup_transactions_txn(txn): - txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) - - await self.db_pool.runInteraction( - "_cleanup_transactions", _cleanup_transactions_txn - ) - async def store_destination_rooms_entries( self, destinations: Iterable[str], room_id: str, stream_ordering: int, ) -> None: diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 3b9211a6d2..79b7ece330 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -288,8 +288,6 @@ class UIAuthWorkerStore(SQLBaseStore): ) return [(row["user_agent"], row["ip"]) for row in rows] - -class UIAuthStore(UIAuthWorkerStore): async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None: """ Remove sessions which were last used earlier than the expiration time. @@ -339,3 +337,7 @@ class UIAuthStore(UIAuthWorkerStore): iterable=session_ids, keyvalues={}, ) + + +class UIAuthStore(UIAuthWorkerStore): + pass diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 5a390ff2f6..ef11f1c3b3 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -17,7 +17,7 @@ import logging import re from typing import Any, Dict, Iterable, Optional, Set, Tuple -from synapse.api.constants import EventTypes, JoinRules +from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules from synapse.storage.database import DatabasePool from synapse.storage.databases.main.state import StateFilter from synapse.storage.databases.main.state_deltas import StateDeltasStore @@ -360,7 +360,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): if hist_vis_id: hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) if hist_vis_ev: - if hist_vis_ev.content.get("history_visibility") == "world_readable": + if ( + hist_vis_ev.content.get("history_visibility") + == HistoryVisibility.WORLD_READABLE + ): return True return False @@ -393,9 +396,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): sql = """ INSERT INTO user_directory_search(user_id, vector) VALUES (?, - setweight(to_tsvector('english', ?), 'A') - || setweight(to_tsvector('english', ?), 'D') - || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + setweight(to_tsvector('simple', ?), 'A') + || setweight(to_tsvector('simple', ?), 'D') + || setweight(to_tsvector('simple', COALESCE(?, '')), 'B') ) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector """ txn.execute( @@ -415,9 +418,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): sql = """ INSERT INTO user_directory_search(user_id, vector) VALUES (?, - setweight(to_tsvector('english', ?), 'A') - || setweight(to_tsvector('english', ?), 'D') - || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + setweight(to_tsvector('simple', ?), 'A') + || setweight(to_tsvector('simple', ?), 'D') + || setweight(to_tsvector('simple', COALESCE(?, '')), 'B') ) """ txn.execute( @@ -432,9 +435,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): elif new_entry is False: sql = """ UPDATE user_directory_search - SET vector = setweight(to_tsvector('english', ?), 'A') - || setweight(to_tsvector('english', ?), 'D') - || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + SET vector = setweight(to_tsvector('simple', ?), 'A') + || setweight(to_tsvector('simple', ?), 'D') + || setweight(to_tsvector('simple', COALESCE(?, '')), 'B') WHERE user_id = ? """ txn.execute( @@ -480,21 +483,16 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): user_id_tuples: iterable of 2-tuple of user IDs. 
""" - def _add_users_who_share_room_txn(txn): - self.db_pool.simple_upsert_many_txn( - txn, - table="users_who_share_private_rooms", - key_names=["user_id", "other_user_id", "room_id"], - key_values=[ - (user_id, other_user_id, room_id) - for user_id, other_user_id in user_id_tuples - ], - value_names=(), - value_values=None, - ) - - await self.db_pool.runInteraction( - "add_users_who_share_room", _add_users_who_share_room_txn + await self.db_pool.simple_upsert_many( + table="users_who_share_private_rooms", + key_names=["user_id", "other_user_id", "room_id"], + key_values=[ + (user_id, other_user_id, room_id) + for user_id, other_user_id in user_id_tuples + ], + value_names=(), + value_values=None, + desc="add_users_who_share_room", ) async def add_users_in_public_rooms( @@ -508,19 +506,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): user_ids """ - def _add_users_in_public_rooms_txn(txn): - - self.db_pool.simple_upsert_many_txn( - txn, - table="users_in_public_rooms", - key_names=["user_id", "room_id"], - key_values=[(user_id, room_id) for user_id in user_ids], - value_names=(), - value_values=None, - ) - - await self.db_pool.runInteraction( - "add_users_in_public_rooms", _add_users_in_public_rooms_txn + await self.db_pool.simple_upsert_many( + table="users_in_public_rooms", + key_names=["user_id", "room_id"], + key_values=[(user_id, room_id) for user_id in user_ids], + value_names=(), + value_values=None, + desc="add_users_in_public_rooms", ) async def delete_all_from_user_dir(self) -> None: @@ -772,7 +764,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): INNER JOIN user_directory AS d USING (user_id) WHERE %s - AND vector @@ to_tsquery('english', ?) + AND vector @@ to_tsquery('simple', ?) ORDER BY (CASE WHEN d.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END) * (CASE WHEN display_name IS NOT NULL THEN 1.2 ELSE 1.0 END) @@ -781,13 +773,13 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): 3 * ts_rank_cd( '{0.1, 0.1, 0.9, 1.0}', vector, - to_tsquery('english', ?), + to_tsquery('simple', ?), 8 ) + ts_rank_cd( '{0.1, 0.1, 0.9, 1.0}', vector, - to_tsquery('english', ?), + to_tsquery('simple', ?), 8 ) ) diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index afd10f7bae..c03871f393 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -17,11 +17,12 @@
 import logging
 
 import attr
+from signedjson.types import VerifyKey
 
 logger = logging.getLogger(__name__)
 
 
 @attr.s(slots=True, frozen=True)
 class FetchKeyResult:
-    verify_key = attr.ib()  # VerifyKey: the key itself
-    valid_until_ts = attr.ib()  # int: how long we can use this key for
+    verify_key = attr.ib(type=VerifyKey)  # the key itself
+    valid_until_ts = attr.ib(type=int)  # how long we can use this key for
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 72939f3984..61fc49c69c 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -31,7 +31,14 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases import Databases from synapse.storage.databases.main.events import DeltaState -from synapse.types import Collection, PersistedEventPosition, RoomStreamToken, StateMap +from synapse.storage.databases.main.events_worker import EventRedactBehaviour +from synapse.types import ( + Collection, + PersistedEventPosition, + RoomStreamToken, + StateMap, + get_domain_from_id, +) from synapse.util.async_helpers import ObservableDeferred from synapse.util.metrics import Measure @@ -68,6 +75,21 @@ stale_forward_extremities_counter = Histogram( buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), ) +state_resolutions_during_persistence = Counter( + "synapse_storage_events_state_resolutions_during_persistence", + "Number of times we had to do state res to calculate new current state", +) + +potential_times_prune_extremities = Counter( + "synapse_storage_events_potential_times_prune_extremities", + "Number of times we might be able to prune extremities", +) + +times_pruned_extremities = Counter( + "synapse_storage_events_times_pruned_extremities", + "Number of times we were actually be able to prune extremities", +) + class _EventPeristenceQueue: """Queues up events so that they can be persisted in bulk with only one @@ -96,7 +118,9 @@ class _EventPeristenceQueue: Returns: defer.Deferred: a deferred which will resolve once the events are - persisted. Runs its callbacks *without* a logcontext. + persisted. Runs its callbacks *without* a logcontext. The result + is the same as that returned by the callback passed to + `handle_queue`. """ queue = self._event_persist_queues.setdefault(room_id, deque()) if queue: @@ -199,7 +223,7 @@ class EventsPersistenceStorage: self, events_and_contexts: Iterable[Tuple[EventBase, EventContext]], backfilled: bool = False, - ) -> RoomStreamToken: + ) -> Tuple[List[EventBase], RoomStreamToken]: """ Write events to the database Args: @@ -209,7 +233,11 @@ class EventsPersistenceStorage: which might update the current state etc. Returns: - the stream ordering of the latest persisted event + List of events persisted, the current position room stream position. + The list of events persisted may not be the same as those passed in + if they were deduplicated due to an event already existing that + matched the transcation ID; the existing event is returned in such + a case. """ partitioned = {} # type: Dict[str, List[Tuple[EventBase, EventContext]]] for event, ctx in events_and_contexts: @@ -225,19 +253,41 @@ class EventsPersistenceStorage: for room_id in partitioned: self._maybe_start_persisting(room_id) - await make_deferred_yieldable( + # Each deferred returns a map from event ID to existing event ID if the + # event was deduplicated. (The dict may also include other entries if + # the event was persisted in a batch with other events). + # + # Since we use `defer.gatherResults` we need to merge the returned list + # of dicts into one. 
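Because persist_events fans the batch out per room, each per-room deferred now resolves with its own event-ID-to-existing-event-ID map, and those partial maps have to be merged before the caller can substitute deduplicated events. A simplified sketch of that merge step, using asyncio as a stand-in for Twisted's gatherResults and a placeholder persistence coroutine:

import asyncio
from typing import Dict, List

async def persist_batch(room_id: str, event_ids: List[str]) -> Dict[str, str]:
    # Placeholder for the real per-room persistence; returns a map of
    # event_id -> already-persisted event_id for any deduplicated events.
    await asyncio.sleep(0)
    return {}

async def persist_all(batches: Dict[str, List[str]]) -> Dict[str, str]:
    results = await asyncio.gather(
        *(persist_batch(room_id, ids) for room_id, ids in batches.items())
    )
    replaced: Dict[str, str] = {}
    for partial in results:
        replaced.update(partial)  # merge the per-room maps into one
    return replaced

# asyncio.run(persist_all({"!room:example.org": ["$event1", "$event2"]}))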
+ ret_vals = await make_deferred_yieldable( defer.gatherResults(deferreds, consumeErrors=True) ) + replaced_events = {} + for d in ret_vals: + replaced_events.update(d) + + events = [] + for event, _ in events_and_contexts: + existing_event_id = replaced_events.get(event.event_id) + if existing_event_id: + events.append(await self.main_store.get_event(existing_event_id)) + else: + events.append(event) - return self.main_store.get_room_max_token() + return ( + events, + self.main_store.get_room_max_token(), + ) async def persist_event( self, event: EventBase, context: EventContext, backfilled: bool = False - ) -> Tuple[PersistedEventPosition, RoomStreamToken]: + ) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]: """ Returns: - The stream ordering of `event`, and the stream ordering of the - latest persisted event + The event, stream ordering of `event`, and the stream ordering of the + latest persisted event. The returned event may not match the given + event if it was deduplicated due to an existing event matching the + transaction ID. """ deferred = self._event_persist_queue.add_to_queue( event.room_id, [(event, context)], backfilled=backfilled @@ -245,17 +295,33 @@ class EventsPersistenceStorage: self._maybe_start_persisting(event.room_id) - await make_deferred_yieldable(deferred) + # The deferred returns a map from event ID to existing event ID if the + # event was deduplicated. (The dict may also include other entries if + # the event was persisted in a batch with other events.) + replaced_events = await make_deferred_yieldable(deferred) + replaced_event = replaced_events.get(event.event_id) + if replaced_event: + event = await self.main_store.get_event(replaced_event) event_stream_id = event.internal_metadata.stream_ordering + # stream ordering should have been assigned by now + assert event_stream_id pos = PersistedEventPosition(self._instance_name, event_stream_id) - return pos, self.main_store.get_room_max_token() + return event, pos, self.main_store.get_room_max_token() def _maybe_start_persisting(self, room_id: str): + """Pokes the `_event_persist_queue` to start handling new items in the + queue, if not already in progress. + + Causes the deferreds returned by `add_to_queue` to resolve with: a + dictionary of event ID to event ID we didn't persist as we already had + another event persisted with the same TXN ID. + """ + async def persisting_queue(item): with Measure(self._clock, "persist_events"): - await self._persist_events( + return await self._persist_events( item.events_and_contexts, backfilled=item.backfilled ) @@ -265,12 +331,38 @@ class EventsPersistenceStorage: self, events_and_contexts: List[Tuple[EventBase, EventContext]], backfilled: bool = False, - ): + ) -> Dict[str, str]: """Calculates the change to current state and forward extremities, and persists the given events and with those updates. + + Returns: + A dictionary of event ID to event ID we didn't persist as we already + had another event persisted with the same TXN ID. """ + replaced_events = {} # type: Dict[str, str] if not events_and_contexts: - return + return replaced_events + + # Check if any of the events have a transaction ID that has already been + # persisted, and if so we don't persist it again. + # + # We should have checked this a long time before we get here, but it's + # possible that different send event requests race in such a way that + # they both pass the earlier checks. Checking here isn't racey as we can + # have only one `_persist_events` per room being called at a time. 
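The final safety check described above drops any event whose transaction ID already matches a persisted event, which is race-free because only one _persist_events call runs per room at a time. The real lookup (get_already_persisted_events) consults the database; the sketch below is an in-memory stand-in that only shows the shape of the filtering and of the returned replacement map:

from typing import Dict, List, Tuple

def filter_already_persisted(
    events: List[Tuple[str, str]],      # (event_id, txn_id) pairs in the batch
    persisted_by_txn: Dict[str, str],   # txn_id -> event_id already in the DB
) -> Tuple[List[Tuple[str, str]], Dict[str, str]]:
    """Drop events whose transaction ID already maps to a persisted event.

    Returns the remaining events plus a map of event_id -> existing event_id,
    mirroring the shape that _persist_events now returns to its callers.
    """
    replaced: Dict[str, str] = {}
    remaining: List[Tuple[str, str]] = []
    for event_id, txn_id in events:
        existing = persisted_by_txn.get(txn_id)
        if existing is not None:
            replaced[event_id] = existing
        else:
            remaining.append((event_id, txn_id))
    return remaining, replaced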
+ replaced_events = await self.main_store.get_already_persisted_events( + (event for event, _ in events_and_contexts) + ) + + if replaced_events: + events_and_contexts = [ + (e, ctx) + for e, ctx in events_and_contexts + if e.event_id not in replaced_events + ] + + if not events_and_contexts: + return replaced_events chunks = [ events_and_contexts[x : x + 100] @@ -384,7 +476,15 @@ class EventsPersistenceStorage: latest_event_ids, new_latest_event_ids, ) - current_state, delta_ids = res + current_state, delta_ids, new_latest_event_ids = res + + # there should always be at least one forward extremity. + # (except during the initial persistence of the send_join + # results, in which case there will be no existing + # extremities, so we'll `continue` above and skip this bit.) + assert new_latest_event_ids, "No forward extremities left!" + + new_forward_extremeties[room_id] = new_latest_event_ids # If either are not None then there has been a change, # and we need to work out the delta (or use that @@ -439,6 +539,8 @@ class EventsPersistenceStorage: await self._handle_potentially_left_users(potentially_left_users) + return replaced_events + async def _calculate_new_extremities( self, room_id: str, @@ -501,29 +603,35 @@ class EventsPersistenceStorage: self, room_id: str, events_context: List[Tuple[EventBase, EventContext]], - old_latest_event_ids: Iterable[str], - new_latest_event_ids: Iterable[str], - ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]: + old_latest_event_ids: Set[str], + new_latest_event_ids: Set[str], + ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]], Set[str]]: """Calculate the current state dict after adding some new events to a room Args: - room_id (str): + room_id: room to which the events are being added. Used for logging etc - events_context (list[(EventBase, EventContext)]): + events_context: events and contexts which are being added to the room - old_latest_event_ids (iterable[str]): + old_latest_event_ids: the old forward extremities for the room. - new_latest_event_ids (iterable[str]): + new_latest_event_ids : the new forward extremities for the room. Returns: - Returns a tuple of two state maps, the first being the full new current - state and the second being the delta to the existing current state. - If both are None then there has been no change. + Returns a tuple of two state maps and a set of new forward + extremities. + + The first state map is the full new current state and the second + is the delta to the existing current state. If both are None then + there has been no change. + + The function may prune some old entries from the set of new + forward extremities if it's safe to do so. If there has been a change then we only return the delta if its already been calculated. Conversely if we do know the delta then @@ -600,7 +708,7 @@ class EventsPersistenceStorage: # If they old and new groups are the same then we don't need to do # anything. if old_state_groups == new_state_groups: - return None, None + return None, None, new_latest_event_ids if len(new_state_groups) == 1 and len(old_state_groups) == 1: # If we're going from one state group to another, lets check if @@ -617,7 +725,7 @@ class EventsPersistenceStorage: # the current state in memory then lets also return that, # but it doesn't matter if we don't. 
new_state = state_groups_map.get(new_state_group) - return new_state, delta_ids + return new_state, delta_ids, new_latest_event_ids # Now that we have calculated new_state_groups we need to get # their state IDs so we can resolve to a single state set. @@ -629,7 +737,7 @@ class EventsPersistenceStorage: if len(new_state_groups) == 1: # If there is only one state group, then we know what the current # state is. - return state_groups_map[new_state_groups.pop()], None + return state_groups_map[new_state_groups.pop()], None, new_latest_event_ids # Ok, we need to defer to the state handler to resolve our state sets. @@ -662,7 +770,139 @@ class EventsPersistenceStorage: state_res_store=StateResolutionStore(self.main_store), ) - return res.state, None + state_resolutions_during_persistence.inc() + + # If the returned state matches the state group of one of the new + # forward extremities then we check if we are able to prune some state + # extremities. + if res.state_group and res.state_group in new_state_groups: + new_latest_event_ids = await self._prune_extremities( + room_id, + new_latest_event_ids, + res.state_group, + event_id_to_state_group, + events_context, + ) + + return res.state, None, new_latest_event_ids + + async def _prune_extremities( + self, + room_id: str, + new_latest_event_ids: Set[str], + resolved_state_group: int, + event_id_to_state_group: Dict[str, int], + events_context: List[Tuple[EventBase, EventContext]], + ) -> Set[str]: + """See if we can prune any of the extremities after calculating the + resolved state. + """ + potential_times_prune_extremities.inc() + + # We keep all the extremities that have the same state group, and + # see if we can drop the others. + new_new_extrems = { + e + for e in new_latest_event_ids + if event_id_to_state_group[e] == resolved_state_group + } + + dropped_extrems = set(new_latest_event_ids) - new_new_extrems + + logger.debug("Might drop extremities: %s", dropped_extrems) + + # We only drop events from the extremities list if: + # 1. we're not currently persisting them; + # 2. they're not our own events (or are dummy events); and + # 3. they're either: + # 1. over N hours old and more than N events ago (we use depth to + # calculate); or + # 2. we are persisting an event from the same domain and more than + # M events ago. + # + # The idea is that we don't want to drop events that are "legitimate" + # extremities (that we would want to include as prev events), only + # "stuck" extremities that are e.g. due to a gap in the graph. + # + # Note that we either drop all of them or none of them. If we only drop + # some of the events we don't know if state res would come to the same + # conclusion. + + for ev, _ in events_context: + if ev.event_id in dropped_extrems: + logger.debug( + "Not dropping extremities: %s is being persisted", ev.event_id + ) + return new_latest_event_ids + + dropped_events = await self.main_store.get_events( + dropped_extrems, + allow_rejected=True, + redact_behaviour=EventRedactBehaviour.AS_IS, + ) + + new_senders = {get_domain_from_id(e.sender) for e, _ in events_context} + + one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000 + current_depth = max(e.depth for e, _ in events_context) + for event in dropped_events.values(): + # If the event is a local dummy event then we should check it + # doesn't reference any local events, as we want to reference those + # if we send any new events. 
+ # + # Note we do this recursively to handle the case where a dummy event + # references a dummy event that only references remote events. + # + # Ideally we'd figure out a way of still being able to drop old + # dummy events that reference local events, but this is good enough + # as a first cut. + events_to_check = [event] + while events_to_check: + new_events = set() + for event_to_check in events_to_check: + if self.is_mine_id(event_to_check.sender): + if event_to_check.type != EventTypes.Dummy: + logger.debug("Not dropping own event") + return new_latest_event_ids + new_events.update(event_to_check.prev_event_ids()) + + prev_events = await self.main_store.get_events( + new_events, + allow_rejected=True, + redact_behaviour=EventRedactBehaviour.AS_IS, + ) + events_to_check = prev_events.values() + + if ( + event.origin_server_ts < one_day_ago + and event.depth < current_depth - 100 + ): + continue + + # We can be less conservative about dropping extremities from the + # same domain, though we do want to wait a little bit (otherwise + # we'll immediately remove all extremities from a given server). + if ( + get_domain_from_id(event.sender) in new_senders + and event.depth < current_depth - 20 + ): + continue + + logger.debug( + "Not dropping as too new and not in new_senders: %s", new_senders, + ) + + return new_latest_event_ids + + times_pruned_extremities.inc() + + logger.info( + "Pruning forward extremities in room %s: from %s -> %s", + room_id, + new_latest_event_ids, + new_new_extrems, + ) + return new_new_extrems async def _calculate_state_delta( self, room_id: str, current_state: StateMap[str] diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
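The pruning logic above only ever drops "stuck" extremities: nothing that is being persisted right now, nothing of our own except dummy events, and otherwise only events that are sufficiently old and far behind in depth, or from a domain that is sending events in the current batch. A small sketch of just the age/depth heuristics applied to one candidate; the all-or-nothing rule and the recursive dummy-event check are deliberately omitted:

import time
from dataclasses import dataclass
from typing import Optional, Set

DAY_MS = 24 * 60 * 60 * 1000

@dataclass
class Extremity:
    event_id: str
    sender_domain: str
    origin_server_ts: int  # milliseconds
    depth: int

def may_drop(
    ext: Extremity,
    current_depth: int,
    new_sender_domains: Set[str],
    now_ms: Optional[int] = None,
) -> bool:
    """Apply the age/depth heuristics to a single candidate extremity."""
    if now_ms is None:
        now_ms = int(time.time() * 1000)
    # Old and far behind the depth of the events being persisted now.
    if ext.origin_server_ts < now_ms - DAY_MS and ext.depth < current_depth - 100:
        return True
    # Extremities from a domain that is sending events in this batch may be
    # dropped sooner, but still only once they have fallen a little behind.
    if ext.sender_domain in new_sender_domains and ext.depth < current_depth - 20:
        return True
    return False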
index 4957e77f4c..f91a2eae7a 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -13,20 +13,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import imp import logging import os import re from collections import Counter -from typing import Optional, TextIO +from typing import Generator, Iterable, List, Optional, TextIO, Tuple import attr +from typing_extensions import Counter as CounterType from synapse.config.homeserver import HomeServerConfig +from synapse.storage.database import LoggingDatabaseConnection from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.engines.postgres import PostgresEngine -from synapse.storage.types import Connection, Cursor +from synapse.storage.types import Cursor from synapse.types import Collection logger = logging.getLogger(__name__) @@ -67,10 +68,10 @@ UNAPPLIED_DELTA_ON_WORKER_ERROR = ( def prepare_database( - db_conn: Connection, + db_conn: LoggingDatabaseConnection, database_engine: BaseDatabaseEngine, config: Optional[HomeServerConfig], - databases: Collection[str] = ["main", "state"], + databases: Collection[str] = ("main", "state"), ): """Prepares a physical database for usage. Will either create all necessary tables or upgrade from an older schema version. @@ -89,7 +90,7 @@ def prepare_database( """ try: - cur = db_conn.cursor() + cur = db_conn.cursor(txn_name="prepare_database") # sqlite does not automatically start transactions for DDL / SELECT statements, # so we start one before running anything. This ensures that any upgrades @@ -155,7 +156,9 @@ def prepare_database( raise -def _setup_new_database(cur, database_engine, databases): +def _setup_new_database( + cur: Cursor, database_engine: BaseDatabaseEngine, databases: Collection[str] +) -> None: """Sets up the physical database by finding a base set of "full schemas" and then applying any necessary deltas, including schemas from the given data stores. @@ -188,10 +191,9 @@ def _setup_new_database(cur, database_engine, databases): folder as well those in the data stores specified. Args: - cur (Cursor): a database cursor - database_engine (DatabaseEngine) - databases (list[str]): The names of the databases to instantiate - on the given physical database. + cur: a database cursor + database_engine + databases: The names of the databases to instantiate on the given physical database. 
""" # We're about to set up a brand new database so we check that its @@ -199,12 +201,11 @@ def _setup_new_database(cur, database_engine, databases): database_engine.check_new_database(cur) current_dir = os.path.join(dir_path, "schema", "full_schemas") - directory_entries = os.listdir(current_dir) # First we find the highest full schema version we have valid_versions = [] - for filename in directory_entries: + for filename in os.listdir(current_dir): try: ver = int(filename) except ValueError: @@ -237,7 +238,7 @@ def _setup_new_database(cur, database_engine, databases): for database in databases ) - directory_entries = [] + directory_entries = [] # type: List[_DirectoryListing] for directory in directories: directory_entries.extend( _DirectoryListing(file_name, os.path.join(directory, file_name)) @@ -258,9 +259,7 @@ def _setup_new_database(cur, database_engine, databases): executescript(cur, entry.absolute_path) cur.execute( - database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded) VALUES (?,?)" - ), + "INSERT INTO schema_version (version, upgraded) VALUES (?,?)", (max_current_ver, False), ) @@ -277,15 +276,15 @@ def _setup_new_database(cur, database_engine, databases): def _upgrade_existing_database( - cur, - current_version, - applied_delta_files, - upgraded, - database_engine, - config, - databases, - is_empty=False, -): + cur: Cursor, + current_version: int, + applied_delta_files: List[str], + upgraded: bool, + database_engine: BaseDatabaseEngine, + config: Optional[HomeServerConfig], + databases: Collection[str], + is_empty: bool = False, +) -> None: """Upgrades an existing physical database. Delta files can either be SQL stored in *.sql files, or python modules @@ -325,21 +324,20 @@ def _upgrade_existing_database( for a version before applying those in the next version. Args: - cur (Cursor) - current_version (int): The current version of the schema. - applied_delta_files (list): A list of deltas that have already been - applied. - upgraded (bool): Whether the current version was generated by having + cur + current_version: The current version of the schema. + applied_delta_files: A list of deltas that have already been applied. + upgraded: Whether the current version was generated by having applied deltas or from full schema file. If `True` the function will never apply delta files for the given `current_version`, since the current_version wasn't generated by applying those delta files. - database_engine (DatabaseEngine) - config (synapse.config.homeserver.HomeServerConfig|None): + database_engine + config: None if we are initialising a blank database, otherwise the application config - databases (list[str]): The names of the databases to instantiate + databases: The names of the databases to instantiate on the given physical database. - is_empty (bool): Is this a blank database? I.e. do we need to run the + is_empty: Is this a blank database? I.e. do we need to run the upgrade portions of the delta scripts. """ if is_empty: @@ -360,6 +358,7 @@ def _upgrade_existing_database( if not is_empty and "main" in databases: from synapse.storage.databases.main import check_database_before_upgrade + assert config is not None check_database_before_upgrade(cur, database_engine, config) start_ver = current_version @@ -390,10 +389,10 @@ def _upgrade_existing_database( ) # Used to check if we have any duplicate file names - file_name_counter = Counter() + file_name_counter = Counter() # type: CounterType[str] # Now find which directories have anything of interest. 
- directory_entries = [] + directory_entries = [] # type: List[_DirectoryListing] for directory in directories: logger.debug("Looking for schema deltas in %s", directory) try: @@ -447,11 +446,11 @@ def _upgrade_existing_database( module_name = "synapse.storage.v%d_%s" % (v, root_name) with open(absolute_path) as python_file: - module = imp.load_source(module_name, absolute_path, python_file) + module = imp.load_source(module_name, absolute_path, python_file) # type: ignore logger.info("Running script %s", relative_path) - module.run_create(cur, database_engine) + module.run_create(cur, database_engine) # type: ignore if not is_empty: - module.run_upgrade(cur, database_engine, config=config) + module.run_upgrade(cur, database_engine, config=config) # type: ignore elif ext == ".pyc" or file_name == "__pycache__": # Sometimes .pyc files turn up anyway even though we've # disabled their generation; e.g. from distribution package @@ -486,31 +485,28 @@ def _upgrade_existing_database( # Mark as done. cur.execute( - database_engine.convert_param_style( - "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)" - ), + "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)", (v, relative_path), ) cur.execute("DELETE FROM schema_version") cur.execute( - database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded) VALUES (?,?)" - ), + "INSERT INTO schema_version (version, upgraded) VALUES (?,?)", (v, True), ) logger.info("Schema now up to date") -def _apply_module_schemas(txn, database_engine, config): +def _apply_module_schemas( + txn: Cursor, database_engine: BaseDatabaseEngine, config: HomeServerConfig +) -> None: """Apply the module schemas for the dynamic modules, if any Args: cur: database cursor - database_engine: synapse database engine class - config (synapse.config.homeserver.HomeServerConfig): - application config + database_engine: + config: application config """ for (mod, _config) in config.password_providers: if not hasattr(mod, "get_db_schema_files"): @@ -521,21 +517,22 @@ def _apply_module_schemas(txn, database_engine, config): ) -def _apply_module_schema_files(cur, database_engine, modname, names_and_streams): +def _apply_module_schema_files( + cur: Cursor, + database_engine: BaseDatabaseEngine, + modname: str, + names_and_streams: Iterable[Tuple[str, TextIO]], +) -> None: """Apply the module schemas for a single module Args: cur: database cursor database_engine: synapse database engine class - modname (str): fully qualified name of the module - names_and_streams (Iterable[(str, file)]): the names and streams of - schemas to be applied + modname: fully qualified name of the module + names_and_streams: the names and streams of schemas to be applied """ cur.execute( - database_engine.convert_param_style( - "SELECT file FROM applied_module_schemas WHERE module_name = ?" - ), - (modname,), + "SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,), ) applied_deltas = {d for d, in cur} for (name, stream) in names_and_streams: @@ -553,14 +550,12 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams) # Mark as done. cur.execute( - database_engine.convert_param_style( - "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)" - ), + "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)", (modname, name), ) -def get_statements(f): +def get_statements(f: Iterable[str]) -> Generator[str, None, None]: statement_buffer = "" in_comment = False # If we're in a /* ... 
*/ style comment @@ -605,17 +600,19 @@ def get_statements(f): statement_buffer = statements[-1].strip() -def executescript(txn, schema_path): +def executescript(txn: Cursor, schema_path: str) -> None: with open(schema_path, "r") as f: execute_statements_from_stream(txn, f) -def execute_statements_from_stream(cur: Cursor, f: TextIO): +def execute_statements_from_stream(cur: Cursor, f: TextIO) -> None: for statement in get_statements(f): cur.execute(statement) -def _get_or_create_schema_state(txn, database_engine): +def _get_or_create_schema_state( + txn: Cursor, database_engine: BaseDatabaseEngine +) -> Optional[Tuple[int, List[str], bool]]: # Bluntly try creating the schema_version tables. schema_path = os.path.join(dir_path, "schema", "schema_version.sql") executescript(txn, schema_path) @@ -623,16 +620,14 @@ def _get_or_create_schema_state(txn, database_engine): txn.execute("SELECT version, upgraded FROM schema_version") row = txn.fetchone() current_version = int(row[0]) if row else None - upgraded = bool(row[1]) if row else None if current_version: txn.execute( - database_engine.convert_param_style( - "SELECT file FROM applied_schema_deltas WHERE version >= ?" - ), + "SELECT file FROM applied_schema_deltas WHERE version >= ?", (current_version,), ) applied_deltas = [d for d, in txn] + upgraded = bool(row[1]) return current_version, applied_deltas, upgraded return None @@ -647,5 +642,5 @@ class _DirectoryListing: `file_name` attr is kept first. """ - file_name = attr.ib() - absolute_path = attr.ib() + file_name = attr.ib(type=str) + absolute_path = attr.ib(type=str) diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
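The prepare_database changes above tighten the typing around a scheme where the newest full schema is applied first and per-version delta files are then replayed in order, with progress recorded in schema_version and applied_schema_deltas. A toy sketch of that bookkeeping for SQL deltas only, using SQLite; it ignores full schemas, Python delta modules, and multiple logical databases, which the real code handles:

import os
import sqlite3

def upgrade(conn: sqlite3.Connection, delta_root: str) -> None:
    """Replay per-version delta directories in order, recording what was applied."""
    cur = conn.cursor()
    cur.execute("CREATE TABLE IF NOT EXISTS schema_version (version INTEGER, upgraded BOOLEAN)")
    cur.execute("CREATE TABLE IF NOT EXISTS applied_schema_deltas (version INTEGER, file TEXT)")

    row = cur.execute("SELECT version FROM schema_version").fetchone()
    current = row[0] if row else 0

    for version in sorted(int(d) for d in os.listdir(delta_root) if d.isdigit()):
        if version < current:
            continue
        version_dir = os.path.join(delta_root, str(version))
        applied = {
            name
            for (name,) in cur.execute(
                "SELECT file FROM applied_schema_deltas WHERE version = ?", (version,)
            )
        }
        for name in sorted(os.listdir(version_dir)):
            if name in applied or not name.endswith(".sql"):
                continue
            with open(os.path.join(version_dir, name)) as schema_file:
                cur.executescript(schema_file.read())
            cur.execute(
                "INSERT INTO applied_schema_deltas (version, file) VALUES (?, ?)",
                (version, name),
            )
        cur.execute("DELETE FROM schema_version")
        cur.execute(
            "INSERT INTO schema_version (version, upgraded) VALUES (?, ?)", (version, True)
        )
    conn.commit()

Tracking applied files per version is what lets a partially upgraded version be resumed safely after an interrupted startup.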
index bfa0a9fd06..6c359c1aae 100644
--- a/synapse/storage/purge_events.py
+++ b/synapse/storage/purge_events.py
@@ -15,7 +15,12 @@ import itertools import logging -from typing import Set +from typing import TYPE_CHECKING, Set + +from synapse.storage.databases import Databases + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) @@ -24,10 +29,10 @@ class PurgeEventsStorage: """High level interface for purging rooms and event history. """ - def __init__(self, hs, stores): + def __init__(self, hs: "HomeServer", stores: Databases): self.stores = stores - async def purge_room(self, room_id: str): + async def purge_room(self, room_id: str) -> None: """Deletes all record of a room """ diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
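Several of the modules in this commit adopt the same pattern shown in the purge_events hunk above: the HomeServer import is guarded by TYPE_CHECKING so the annotation is available to mypy without creating an import cycle at runtime, and the annotation itself is written as a string. Restated on its own:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by type checkers, so no import cycle exists at runtime.
    from synapse.app.homeserver import HomeServer


class PurgeEventsStorage:
    def __init__(self, hs: "HomeServer", stores) -> None:
        # The quoted annotation is never resolved at runtime; it only has to
        # be importable when a type checker analyses the module.
        self.hs = hs
        self.stores = stores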
index cec96ad6a7..2564f34b47 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -14,10 +14,12 @@ # limitations under the License. import logging +from typing import Any, Dict, List, Optional, Tuple import attr from synapse.api.errors import SynapseError +from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -27,18 +29,18 @@ class PaginationChunk: """Returned by relation pagination APIs. Attributes: - chunk (list): The rows returned by pagination - next_batch (Any|None): Token to fetch next set of results with, if + chunk: The rows returned by pagination + next_batch: Token to fetch next set of results with, if None then there are no more results. - prev_batch (Any|None): Token to fetch previous set of results with, if + prev_batch: Token to fetch previous set of results with, if None then there are no previous results. """ - chunk = attr.ib() - next_batch = attr.ib(default=None) - prev_batch = attr.ib(default=None) + chunk = attr.ib(type=List[JsonDict]) + next_batch = attr.ib(type=Optional[Any], default=None) + prev_batch = attr.ib(type=Optional[Any], default=None) - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: d = {"chunk": self.chunk} if self.next_batch: @@ -59,25 +61,25 @@ class RelationPaginationToken: boundaries of the chunk as pagination tokens. Attributes: - topological (int): The topological ordering of the boundary event - stream (int): The stream ordering of the boundary event. + topological: The topological ordering of the boundary event + stream: The stream ordering of the boundary event. """ - topological = attr.ib() - stream = attr.ib() + topological = attr.ib(type=int) + stream = attr.ib(type=int) @staticmethod - def from_string(string): + def from_string(string: str) -> "RelationPaginationToken": try: t, s = string.split("-") return RelationPaginationToken(int(t), int(s)) except ValueError: raise SynapseError(400, "Invalid token") - def to_string(self): + def to_string(self) -> str: return "%d-%d" % (self.topological, self.stream) - def as_tuple(self): + def as_tuple(self) -> Tuple[Any, ...]: return attr.astuple(self) @@ -89,23 +91,23 @@ class AggregationPaginationToken: aggregation groups, we can just use them as our pagination token. Attributes: - count (int): The count of relations in the boundar group. - stream (int): The MAX stream ordering in the boundary group. + count: The count of relations in the boundary group. + stream: The MAX stream ordering in the boundary group. """ - count = attr.ib() - stream = attr.ib() + count = attr.ib(type=int) + stream = attr.ib(type=int) @staticmethod - def from_string(string): + def from_string(string: str) -> "AggregationPaginationToken": try: c, s = string.split("-") return AggregationPaginationToken(int(c), int(s)) except ValueError: raise SynapseError(400, "Invalid token") - def to_string(self): + def to_string(self) -> str: return "%d-%d" % (self.count, self.stream) - def as_tuple(self): + def as_tuple(self) -> Tuple[Any, ...]: return attr.astuple(self) diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 08a69f2f96..31ccbf23dc 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -12,9 +12,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging -from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar +from typing import ( + TYPE_CHECKING, + Awaitable, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + TypeVar, +) import attr @@ -22,6 +31,10 @@ from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.types import MutableStateMap, StateMap +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + from synapse.storage.databases import Databases + logger = logging.getLogger(__name__) # Used for generic functions below @@ -330,10 +343,12 @@ class StateGroupStorage: """High level interface to fetching state for event. """ - def __init__(self, hs, stores): + def __init__(self, hs: "HomeServer", stores: "Databases"): self.stores = stores - async def get_state_group_delta(self, state_group: int): + async def get_state_group_delta( + self, state_group: int + ) -> Tuple[Optional[int], Optional[StateMap[str]]]: """Given a state group try to return a previous group and a delta between the old and the new. @@ -341,8 +356,8 @@ class StateGroupStorage: state_group: The state group used to retrieve state deltas. Returns: - Tuple[Optional[int], Optional[StateMap[str]]]: - (prev_group, delta_ids) + A tuple of the previous group and a state map of the event IDs which + make up the delta between the old and new state groups. """ return await self.stores.state.get_state_group_delta(state_group) @@ -436,7 +451,7 @@ class StateGroupStorage: async def get_state_for_events( self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() - ): + ) -> Dict[str, StateMap[EventBase]]: """Given a list of event_ids and type tuples, return a list of state dicts for each event. @@ -472,7 +487,7 @@ class StateGroupStorage: async def get_state_ids_for_events( self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() - ): + ) -> Dict[str, StateMap[str]]: """ Get the state dicts corresponding to a list of events, containing the event_ids of the state events (as opposed to the events themselves) @@ -500,7 +515,7 @@ class StateGroupStorage: async def get_state_for_event( self, event_id: str, state_filter: StateFilter = StateFilter.all() - ): + ) -> StateMap[EventBase]: """ Get the state dict corresponding to a particular event @@ -516,7 +531,7 @@ class StateGroupStorage: async def get_state_ids_for_event( self, event_id: str, state_filter: StateFilter = StateFilter.all() - ): + ) -> StateMap[str]: """ Get the state dict corresponding to a particular event diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 2d2b560e74..9cadcba18f 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Iterable, Iterator, List, Tuple +from typing import Any, Iterable, Iterator, List, Optional, Tuple from typing_extensions import Protocol @@ -61,3 +61,9 @@ class Connection(Protocol): def rollback(self, *args, **kwargs) -> None: ... + + def __enter__(self) -> "Connection": + ... + + def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]: + ... diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
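The Connection protocol above gains __enter__ and __exit__ so that connections can be used in "with" blocks while remaining purely structural types. A brief sketch of how such a protocol is satisfied by duck typing rather than inheritance, following the shape in the diff:

from typing import Optional
from typing_extensions import Protocol


class Connection(Protocol):
    def commit(self) -> None:
        ...

    def rollback(self) -> None:
        ...

    def __enter__(self) -> "Connection":
        ...

    def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
        ...


def run_in_transaction(conn: Connection) -> None:
    # Any DB-API connection object that happens to have these methods
    # (sqlite3, psycopg2, ...) satisfies the protocol structurally.
    with conn:
        pass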
index eccd2d5b7b..133c0e7a28 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -55,7 +55,7 @@ def _load_current_id(db_conn, table, column, step=1): """ # debug logging for https://github.com/matrix-org/synapse/issues/7968 logger.info("initialising stream generator for %s(%s)", table, column) - cur = db_conn.cursor() + cur = db_conn.cursor(txn_name="_load_current_id") if step == 1: cur.execute("SELECT MAX(%s) FROM %s" % (column, table)) else: @@ -153,12 +153,12 @@ class StreamIdGenerator: return _AsyncCtxManagerWrapper(manager()) - def get_current_token(self): + def get_current_token(self) -> int: """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. Returns: - int + The maximum stream id. """ with self._lock: if self._unfinished_ids: @@ -270,7 +270,7 @@ class MultiWriterIdGenerator: def _load_current_ids( self, db_conn, table: str, instance_column: str, id_column: str ): - cur = db_conn.cursor() + cur = db_conn.cursor(txn_name="_load_current_ids") # Load the current positions of all writers for the stream. if self._writers: @@ -284,15 +284,12 @@ class MultiWriterIdGenerator: stream_name = ? AND instance_name != ALL(?) """ - sql = self._db.engine.convert_param_style(sql) cur.execute(sql, (self._stream_name, self._writers)) sql = """ SELECT instance_name, stream_id FROM stream_positions WHERE stream_name = ? """ - sql = self._db.engine.convert_param_style(sql) - cur.execute(sql, (self._stream_name,)) self._current_positions = { @@ -341,7 +338,6 @@ class MultiWriterIdGenerator: "instance": instance_column, "cmp": "<=" if self._positive else ">=", } - sql = self._db.engine.convert_param_style(sql) cur.execute(sql, (min_stream_id * self._return_factor,)) self._persisted_upto_position = min_stream_id @@ -422,7 +418,7 @@ class MultiWriterIdGenerator: self._unfinished_ids.discard(next_id) self._finished_ids.add(next_id) - new_cur = None + new_cur = None # type: Optional[int] if self._unfinished_ids: # If there are unfinished IDs then the new position will be the @@ -528,6 +524,16 @@ class MultiWriterIdGenerator: heapq.heappush(self._known_persisted_positions, new_id) + # If we're a writer and we don't have any active writes we update our + # current position to the latest position seen. This allows the instance + # to report a recent position when asked, rather than a potentially old + # one (if this instance hasn't written anything for a while). + our_current_position = self._current_positions.get(self._instance_name) + if our_current_position and not self._unfinished_ids: + self._current_positions[self._instance_name] = max( + our_current_position, new_id + ) + # We move the current min position up if the minimum current positions # of all instances is higher (since by definition all positions less # that that have been persisted). diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
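The id-generator change above lets an idle writer bump its advertised position to the newest position it has seen, so the stream-wide minimum (which other readers wait on) does not get stuck behind an instance that has not written for a while. A much-simplified model of that bookkeeping, ignoring persistence, negative streams, and in-flight ID tracking:

from typing import Dict, Set

class PositionTracker:
    def __init__(self, instance_name: str, initial: Dict[str, int]):
        self.instance_name = instance_name
        self.positions = dict(initial)     # instance name -> last persisted id
        self.unfinished: Set[int] = set()  # ids we have allocated but not finished

    def advance(self, instance: str, new_id: int) -> None:
        """Record that `instance` has persisted everything up to new_id."""
        self.positions[instance] = max(self.positions.get(instance, 0), new_id)
        # If we have no writes in flight, adopt the newest position we have
        # seen so that we do not hold the stream-wide minimum back.
        if not self.unfinished:
            ours = self.positions.get(self.instance_name, 0)
            self.positions[self.instance_name] = max(ours, new_id)

    def get_current_token(self) -> int:
        """Everything up to and including this id is persisted by every writer."""
        return min(self.positions.values()) if self.positions else 0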
index 2dd95e2709..4386b6101e 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py
@@ -17,6 +17,7 @@ import logging import threading from typing import Callable, List, Optional +from synapse.storage.database import LoggingDatabaseConnection from synapse.storage.engines import ( BaseDatabaseEngine, IncorrectDatabaseSetup, @@ -53,7 +54,11 @@ class SequenceGenerator(metaclass=abc.ABCMeta): @abc.abstractmethod def check_consistency( - self, db_conn: Connection, table: str, id_column: str, positive: bool = True + self, + db_conn: LoggingDatabaseConnection, + table: str, + id_column: str, + positive: bool = True, ): """Should be called during start up to test that the current value of the sequence is greater than or equal to the maximum ID in the table. @@ -82,9 +87,13 @@ class PostgresSequenceGenerator(SequenceGenerator): return [i for (i,) in txn] def check_consistency( - self, db_conn: Connection, table: str, id_column: str, positive: bool = True + self, + db_conn: LoggingDatabaseConnection, + table: str, + id_column: str, + positive: bool = True, ): - txn = db_conn.cursor() + txn = db_conn.cursor(txn_name="sequence.check_consistency") # First we get the current max ID from the table. table_sql = "SELECT GREATEST(%(agg)s(%(id)s), 0) FROM %(table)s" % { @@ -117,6 +126,8 @@ class PostgresSequenceGenerator(SequenceGenerator): if max_stream_id > last_value: logger.warning( "Postgres sequence %s is behind table %s: %d < %d", + self._sequence_name, + table, last_value, max_stream_id, ) diff --git a/synapse/types.py b/synapse/types.py
index bd271f9f16..c7d4e95809 100644 --- a/synapse/types.py +++ b/synapse/types.py
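Besides threading an authenticated_entity through Requester, the hunks below document and implement a third, vector-clock style "m" token format for RoomStreamToken. A self-contained sketch of just that wire format as described in the docstring below (Synapse's real parser additionally resolves the numeric instance IDs to instance names via the database, which is omitted here):

    def encode_m_token(min_pos: int, instance_map: dict) -> str:
        # m{min_pos}~{instance_id}.{pos}~... ; only writers ahead of min_pos appear.
        entries = "~".join("{}.{}".format(iid, pos) for iid, pos in instance_map.items())
        return "m{}~{}".format(min_pos, entries)

    def decode_m_token(token: str):
        assert token[0] == "m"
        parts = token[1:].split("~")
        min_pos = int(parts[0])
        instance_map = {}
        for part in parts[1:]:
            iid, pos = part.split(".")
            instance_map[int(iid)] = int(pos)
        return min_pos, instance_map

    # "m56~2.58~3.59": writers 2 and 3 are at positions 58 and 59, everyone else at 56.
    assert decode_m_token("m56~2.58~3.59") == (56, {2: 58, 3: 59})
    assert encode_m_token(56, {2: 58, 3: 59}) == "m56~2.58~3.59"
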
@@ -22,12 +22,14 @@ from typing import ( TYPE_CHECKING, Any, Dict, + Iterable, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar, + Union, ) import attr @@ -37,13 +39,14 @@ from unpaddedbase64 import decode_base64 from synapse.api.errors import Codes, SynapseError if TYPE_CHECKING: + from synapse.appservice.api import ApplicationService from synapse.storage.databases.main import DataStore # define a version of typing.Collection that works on python 3.5 if sys.version_info[:3] >= (3, 6, 0): from typing import Collection else: - from typing import Container, Iterable, Sized + from typing import Container, Sized T_co = TypeVar("T_co", covariant=True) @@ -73,6 +76,7 @@ class Requester( "shadow_banned", "device_id", "app_service", + "authenticated_entity", ], ) ): @@ -103,6 +107,7 @@ class Requester( "shadow_banned": self.shadow_banned, "device_id": self.device_id, "app_server_id": self.app_service.id if self.app_service else None, + "authenticated_entity": self.authenticated_entity, } @staticmethod @@ -128,16 +133,18 @@ class Requester( shadow_banned=input["shadow_banned"], device_id=input["device_id"], app_service=appservice, + authenticated_entity=input["authenticated_entity"], ) def create_requester( - user_id, - access_token_id=None, - is_guest=False, - shadow_banned=False, - device_id=None, - app_service=None, + user_id: Union[str, "UserID"], + access_token_id: Optional[int] = None, + is_guest: Optional[bool] = False, + shadow_banned: Optional[bool] = False, + device_id: Optional[str] = None, + app_service: Optional["ApplicationService"] = None, + authenticated_entity: Optional[str] = None, ): """ Create a new ``Requester`` object @@ -150,14 +157,27 @@ def create_requester( shadow_banned (bool): True if the user making this request is shadow-banned. device_id (str|None): device_id which was set at authentication time app_service (ApplicationService|None): the AS requesting on behalf of the user + authenticated_entity: The entity that authenticated when making the request. + This is different to the user_id when an admin user or the server is + "puppeting" the user. Returns: Requester """ if not isinstance(user_id, UserID): user_id = UserID.from_string(user_id) + + if authenticated_entity is None: + authenticated_entity = user_id.to_string() + return Requester( - user_id, access_token_id, is_guest, shadow_banned, device_id, app_service + user_id, + access_token_id, + is_guest, + shadow_banned, + device_id, + app_service, + authenticated_entity, ) @@ -297,14 +317,14 @@ mxid_localpart_allowed_characters = set( ) -def contains_invalid_mxid_characters(localpart): +def contains_invalid_mxid_characters(localpart: str) -> bool: """Check for characters not allowed in an mxid or groupid localpart Args: - localpart (basestring): the localpart to be checked + localpart: the localpart to be checked Returns: - bool: True if there are any naughty characters + True if there are any naughty characters """ return any(c not in mxid_localpart_allowed_characters for c in localpart) @@ -329,15 +349,17 @@ NON_MXID_CHARACTER_PATTERN = re.compile( ) -def map_username_to_mxid_localpart(username, case_sensitive=False): +def map_username_to_mxid_localpart( + username: Union[str, bytes], case_sensitive: bool = False +) -> str: """Map a username onto a string suitable for a MXID This follows the algorithm laid out at https://matrix.org/docs/spec/appendices.html#mapping-from-other-character-sets. 
Args: - username (unicode|bytes): username to be mapped - case_sensitive (bool): true if TEST and test should be mapped + username: username to be mapped + case_sensitive: true if TEST and test should be mapped onto different mxids Returns: @@ -375,7 +397,7 @@ def map_username_to_mxid_localpart(username, case_sensitive=False): return username.decode("ascii") -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, cmp=False) class RoomStreamToken: """Tokens are positions between events. The token "s1" comes after event 1. @@ -397,6 +419,31 @@ class RoomStreamToken: event it comes after. Historic tokens start with a "t" followed by the "topological_ordering" id of the event it comes after, followed by "-", followed by the "stream_ordering" id of the event it comes after. + + There is also a third mode for live tokens where the token starts with "m", + which is sometimes used when using sharded event persisters. In this case + the events stream is considered to be a set of streams (one for each writer) + and the token encodes the vector clock of positions of each writer in their + respective streams. + + The format of the token in such case is an initial integer min position, + followed by the mapping of instance ID to position separated by '.' and '~': + + m{min_pos}~{writer1}.{pos1}~{writer2}.{pos2}. ... + + The `min_pos` corresponds to the minimum position all writers have persisted + up to, and then only writers that are ahead of that position need to be + encoded. An example token is: + + m56~2.58~3.59 + + Which corresponds to a set of three (or more writers) where instances 2 and + 3 (these are instance IDs that can be looked up in the DB to fetch the more + commonly used instance names) are at positions 58 and 59 respectively, and + all other instances are at position 56. + + Note: The `RoomStreamToken` cannot have both a topological part and an + instance map. """ topological = attr.ib( @@ -405,6 +452,25 @@ class RoomStreamToken: ) stream = attr.ib(type=int, validator=attr.validators.instance_of(int)) + instance_map = attr.ib( + type=Dict[str, int], + factory=dict, + validator=attr.validators.deep_mapping( + key_validator=attr.validators.instance_of(str), + value_validator=attr.validators.instance_of(int), + mapping_validator=attr.validators.instance_of(dict), + ), + ) + + def __attrs_post_init__(self): + """Validates that both `topological` and `instance_map` aren't set. + """ + + if self.instance_map and self.topological: + raise ValueError( + "Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'." 
+ ) + @classmethod async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken": try: @@ -413,6 +479,20 @@ class RoomStreamToken: if string[0] == "t": parts = string[1:].split("-", 1) return cls(topological=int(parts[0]), stream=int(parts[1])) + if string[0] == "m": + parts = string[1:].split("~") + stream = int(parts[0]) + + instance_map = {} + for part in parts[1:]: + key, value = part.split(".") + instance_id = int(key) + pos = int(value) + + instance_name = await store.get_name_from_instance_id(instance_id) + instance_map[instance_name] = pos + + return cls(topological=None, stream=stream, instance_map=instance_map,) except Exception: pass raise SynapseError(400, "Invalid token %r" % (string,)) @@ -436,14 +516,61 @@ class RoomStreamToken: max_stream = max(self.stream, other.stream) - return RoomStreamToken(None, max_stream) + instance_map = { + instance: max( + self.instance_map.get(instance, self.stream), + other.instance_map.get(instance, other.stream), + ) + for instance in set(self.instance_map).union(other.instance_map) + } + + return RoomStreamToken(None, max_stream, instance_map) + + def as_historical_tuple(self) -> Tuple[int, int]: + """Returns a tuple of `(topological, stream)` for historical tokens. + + Raises if not an historical token (i.e. doesn't have a topological part). + """ + if self.topological is None: + raise Exception( + "Cannot call `RoomStreamToken.as_historical_tuple` on live token" + ) - def as_tuple(self) -> Tuple[Optional[int], int]: return (self.topological, self.stream) + def get_stream_pos_for_instance(self, instance_name: str) -> int: + """Get the stream position that the given writer was at at this token. + + This only makes sense for "live" tokens that may have a vector clock + component, and so asserts that this is a "live" token. + """ + assert self.topological is None + + # If we don't have an entry for the instance we can assume that it was + # at `self.stream`. + return self.instance_map.get(instance_name, self.stream) + + def get_max_stream_pos(self) -> int: + """Get the maximum stream position referenced in this token. + + The corresponding "min" position is, by definition just `self.stream`. + + This is used to handle tokens that have non-empty `instance_map`, and so + reference stream positions after the `self.stream` position. + """ + return max(self.instance_map.values(), default=self.stream) + async def to_string(self, store: "DataStore") -> str: if self.topological is not None: return "t%d-%d" % (self.topological, self.stream) + elif self.instance_map: + entries = [] + for name, pos in self.instance_map.items(): + instance_id = await store.get_id_for_instance(name) + entries.append("{}.{}".format(instance_id, pos)) + + encoded_map = "~".join(entries) + return "m{}~{}".format(self.stream, encoded_map) else: return "s%d" % (self.stream,) @@ -535,7 +662,7 @@ class PersistedEventPosition: stream = attr.ib(type=int) def persisted_after(self, token: RoomStreamToken) -> bool: - return token.stream < self.stream + return token.get_stream_pos_for_instance(self.instance_name) < self.stream def to_room_stream_token(self) -> RoomStreamToken: """Converts the position to a room stream token such that events diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index d55b93d763..517686f0a6 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py
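The change below folds frozendict handling into the shared json_encoder (previously a separate frozendict_json_encoder lived in frozenutils.py; it is removed further down). A quick usage sketch of the resulting encoder:

    from frozendict import frozendict

    from synapse.util import json_encoder

    # frozendicts nested in an event body now serialise like plain dicts,
    # with the compact separators the encoder already used.
    assert json_encoder.encode({"body": frozendict({"a": 1})}) == '{"body":{"a":1}}'
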
@@ -18,6 +18,7 @@ import logging import re import attr +from frozendict import frozendict from twisted.internet import defer, task @@ -31,9 +32,26 @@ def _reject_invalid_json(val): raise ValueError("Invalid JSON value: '%s'" % val) -# Create a custom encoder to reduce the whitespace produced by JSON encoding and -# ensure that valid JSON is produced. -json_encoder = json.JSONEncoder(allow_nan=False, separators=(",", ":")) +def _handle_frozendict(obj): + """Helper for json_encoder. Makes frozendicts serializable by returning + the underlying dict + """ + if type(obj) is frozendict: + # fishing the protected dict out of the object is a bit nasty, + # but we don't really want the overhead of copying the dict. + return obj._dict + raise TypeError( + "Object of type %s is not JSON serializable" % obj.__class__.__name__ + ) + + +# A custom JSON encoder which: +# * handles frozendicts +# * produces valid JSON (no NaNs etc) +# * reduces redundant whitespace +json_encoder = json.JSONEncoder( + allow_nan=False, separators=(",", ":"), default=_handle_frozendict +) # Create a custom decoder to reject Python extensions to JSON. json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 382f0cf3f0..9a873c8e8e 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py
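maybe_awaitable below now uses inspect.isawaitable rather than sniffing for __await__, and gains a type signature. In use it lets callers await a value without caring whether the callee returned a plain result or a coroutine; a small sketch, assuming an async test harness to run it in:

    from synapse.util.async_helpers import maybe_awaitable

    async def _demo() -> None:
        async def fetch() -> int:
            return 42

        assert await maybe_awaitable(42) == 42        # plain value, wrapped in DoneAwaitable
        assert await maybe_awaitable(fetch()) == 42   # real awaitable passes straight through
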
@@ -15,10 +15,12 @@ # limitations under the License. import collections +import inspect import logging from contextlib import contextmanager from typing import ( Any, + Awaitable, Callable, Dict, Hashable, @@ -542,11 +544,11 @@ class DoneAwaitable: raise StopIteration(self.value) -def maybe_awaitable(value): +def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]: """Convert a value to an awaitable if not already an awaitable. """ - - if hasattr(value, "__await__"): + if inspect.isawaitable(value): + assert isinstance(value, Awaitable) return value return DoneAwaitable(value) diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 8fc05be278..89f0b38535 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py
@@ -16,7 +16,7 @@ import logging from sys import intern -from typing import Callable, Dict, Optional +from typing import Callable, Dict, Optional, Sized import attr from prometheus_client.core import Gauge @@ -92,7 +92,7 @@ class CacheMetric: def register_cache( cache_type: str, cache_name: str, - cache, + cache: Sized, collect_callback: Optional[Callable] = None, resizable: bool = True, resize_callback: Optional[Callable] = None, @@ -100,12 +100,15 @@ def register_cache( """Register a cache object for metric collection and resizing. Args: - cache_type + cache_type: a string indicating the "type" of the cache. This is used + only for deduplication so isn't too important provided it's constant. cache_name: name of the cache - cache: cache itself + cache: cache itself, which must implement __len__(), and may optionally implement + a max_size property collect_callback: If given, a function which is called during metric collection to update additional metrics. - resizable: Whether this cache supports being resized. + resizable: Whether this cache supports being resized, in which case either + resize_callback must be provided, or the cache must support set_max_size(). resize_callback: A function which can be called to resize the cache. Returns: diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py new file mode 100644
index 0000000000..601305487c --- /dev/null +++ b/synapse/util/caches/deferred_cache.py
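This new module is the old Cache class from descriptors.py split out and renamed to DeferredCache (the old copy is deleted further down). A rough usage sketch, with the logcontext handling that the docstrings insist on elided for brevity; "my_cache" and the values are placeholders:

    from twisted.internet import defer

    from synapse.util.caches.deferred_cache import DeferredCache

    cache = DeferredCache("my_cache")  # type: DeferredCache[str, int]

    d = defer.Deferred()  # type: defer.Deferred
    cache.set("key", d)           # entries are set as Deferreds...
    pending = cache.get("key")    # ...and get() returns a Deferred too
    d.callback(7)                 # once the set() deferred fires,
    assert cache.get_immediate("key", None) == 7   # the completed value is cached
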
@@ -0,0 +1,342 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum +import threading +from typing import ( + Callable, + Generic, + Iterable, + MutableMapping, + Optional, + TypeVar, + Union, + cast, +) + +from prometheus_client import Gauge + +from twisted.internet import defer +from twisted.python import failure + +from synapse.util.async_helpers import ObservableDeferred +from synapse.util.caches.lrucache import LruCache +from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry + +cache_pending_metric = Gauge( + "synapse_util_caches_cache_pending", + "Number of lookups currently pending for this cache", + ["name"], +) + +T = TypeVar("T") +KT = TypeVar("KT") +VT = TypeVar("VT") + + +class _Sentinel(enum.Enum): + # defining a sentinel in this way allows mypy to correctly handle the + # type of a dictionary lookup. + sentinel = object() + + +class DeferredCache(Generic[KT, VT]): + """Wraps an LruCache, adding support for Deferred results. + + It expects that each entry added with set() will be a Deferred; likewise get() + will return a Deferred. + """ + + __slots__ = ( + "cache", + "thread", + "_pending_deferred_cache", + ) + + def __init__( + self, + name: str, + max_entries: int = 1000, + keylen: int = 1, + tree: bool = False, + iterable: bool = False, + apply_cache_factor_from_config: bool = True, + ): + """ + Args: + name: The name of the cache + max_entries: Maximum amount of entries that the cache will hold + keylen: The length of the tuple used as the cache key. Ignored unless + `tree` is True. + tree: Use a TreeCache instead of a dict as the underlying cache type + iterable: If True, count each item in the cached object as an entry, + rather than each cached object + apply_cache_factor_from_config: Whether cache factors specified in the + config file affect `max_entries` + """ + cache_type = TreeCache if tree else dict + + # _pending_deferred_cache maps from the key value to a `CacheEntry` object. + self._pending_deferred_cache = ( + cache_type() + ) # type: MutableMapping[KT, CacheEntry] + + def metrics_cb(): + cache_pending_metric.labels(name).set(len(self._pending_deferred_cache)) + + # cache is used for completed results and maps to the result itself, rather than + # a Deferred. 
+ self.cache = LruCache( + max_size=max_entries, + keylen=keylen, + cache_name=name, + cache_type=cache_type, + size_callback=(lambda d: len(d)) if iterable else None, + metrics_collection_callback=metrics_cb, + apply_cache_factor_from_config=apply_cache_factor_from_config, + ) # type: LruCache[KT, VT] + + self.thread = None # type: Optional[threading.Thread] + + @property + def max_entries(self): + return self.cache.max_size + + def check_thread(self): + expected_thread = self.thread + if expected_thread is None: + self.thread = threading.current_thread() + else: + if expected_thread is not threading.current_thread(): + raise ValueError( + "Cache objects can only be accessed from the main thread" + ) + + def get( + self, + key: KT, + callback: Optional[Callable[[], None]] = None, + update_metrics: bool = True, + ) -> defer.Deferred: + """Looks the key up in the caches. + + For symmetry with set(), this method does *not* follow the synapse logcontext + rules: the logcontext will not be cleared on return, and the Deferred will run + its callbacks in the sentinel context. In other words: wrap the result with + make_deferred_yieldable() before `await`ing it. + + Args: + key: + callback: Gets called when the entry in the cache is invalidated + update_metrics (bool): whether to update the cache hit rate metrics + + Returns: + A Deferred which completes with the result. Note that this may later fail + if there is an ongoing set() operation which later completes with a failure. + + Raises: + KeyError if the key is not found in the cache + """ + callbacks = [callback] if callback else [] + val = self._pending_deferred_cache.get(key, _Sentinel.sentinel) + if val is not _Sentinel.sentinel: + val.callbacks.update(callbacks) + if update_metrics: + m = self.cache.metrics + assert m # we always have a name, so should always have metrics + m.inc_hits() + return val.deferred.observe() + + val2 = self.cache.get( + key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics + ) + if val2 is _Sentinel.sentinel: + raise KeyError() + else: + return defer.succeed(val2) + + def get_immediate( + self, key: KT, default: T, update_metrics: bool = True + ) -> Union[VT, T]: + """If we have a *completed* cached value, return it.""" + return self.cache.get(key, default, update_metrics=update_metrics) + + def set( + self, + key: KT, + value: defer.Deferred, + callback: Optional[Callable[[], None]] = None, + ) -> defer.Deferred: + """Adds a new entry to the cache (or updates an existing one). + + The given `value` *must* be a Deferred. + + First any existing entry for the same key is invalidated. Then a new entry + is added to the cache for the given key. + + Until the `value` completes, calls to `get()` for the key will also result in an + incomplete Deferred, which will ultimately complete with the same result as + `value`. + + If `value` completes successfully, subsequent calls to `get()` will then return + a completed deferred with the same result. If it *fails*, the cache is + invalidated and subequent calls to `get()` will raise a KeyError. + + If another call to `set()` happens before `value` completes, then (a) any + invalidation callbacks registered in the interim will be called, (b) any + `get()`s in the interim will continue to complete with the result from the + *original* `value`, (c) any future calls to `get()` will complete with the + result from the *new* `value`. 
+ + It is expected that `value` does *not* follow the synapse logcontext rules - ie, + if it is incomplete, it runs its callbacks in the sentinel context. + + Args: + key: Key to be set + value: a deferred which will complete with a result to add to the cache + callback: An optional callback to be called when the entry is invalidated + """ + if not isinstance(value, defer.Deferred): + raise TypeError("not a Deferred") + + callbacks = [callback] if callback else [] + self.check_thread() + + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry: + existing_entry.invalidate() + + # XXX: why don't we invalidate the entry in `self.cache` yet? + + # we can save a whole load of effort if the deferred is ready. + if value.called: + result = value.result + if not isinstance(result, failure.Failure): + self.cache.set(key, result, callbacks) + return value + + # otherwise, we'll add an entry to the _pending_deferred_cache for now, + # and add callbacks to add it to the cache properly later. + + observable = ObservableDeferred(value, consumeErrors=True) + observer = observable.observe() + entry = CacheEntry(deferred=observable, callbacks=callbacks) + + self._pending_deferred_cache[key] = entry + + def compare_and_pop(): + """Check if our entry is still the one in _pending_deferred_cache, and + if so, pop it. + + Returns true if the entries matched. + """ + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry is entry: + return True + + # oops, the _pending_deferred_cache has been updated since + # we started our query, so we are out of date. + # + # Better put back whatever we took out. (We do it this way + # round, rather than peeking into the _pending_deferred_cache + # and then removing on a match, to make the common case faster) + if existing_entry is not None: + self._pending_deferred_cache[key] = existing_entry + + return False + + def cb(result): + if compare_and_pop(): + self.cache.set(key, result, entry.callbacks) + else: + # we're not going to put this entry into the cache, so need + # to make sure that the invalidation callbacks are called. + # That was probably done when _pending_deferred_cache was + # updated, but it's possible that `set` was called without + # `invalidate` being previously called, in which case it may + # not have been. Either way, let's double-check now. + entry.invalidate() + + def eb(_fail): + compare_and_pop() + entry.invalidate() + + # once the deferred completes, we can move the entry from the + # _pending_deferred_cache to the real cache. + # + observer.addCallbacks(cb, eb) + + # we return a new Deferred which will be called before any subsequent observers. + return observable.observe() + + def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None): + callbacks = [callback] if callback else [] + self.cache.set(key, value, callbacks=callbacks) + + def invalidate(self, key): + self.check_thread() + self.cache.pop(key, None) + + # if we have a pending lookup for this key, remove it from the + # _pending_deferred_cache, which will (a) stop it being returned + # for future queries and (b) stop it being persisted as a proper entry + # in self.cache. + entry = self._pending_deferred_cache.pop(key, None) + + # run the invalidation callbacks now, rather than waiting for the + # deferred to resolve. 
+ if entry: + entry.invalidate() + + def invalidate_many(self, key: KT): + self.check_thread() + if not isinstance(key, tuple): + raise TypeError("The cache key must be a tuple not %r" % (type(key),)) + key = cast(KT, key) + self.cache.del_multi(key) + + # if we have a pending lookup for this key, remove it from the + # _pending_deferred_cache, as above + entry_dict = self._pending_deferred_cache.pop(key, None) + if entry_dict is not None: + for entry in iterate_tree_cache_entry(entry_dict): + entry.invalidate() + + def invalidate_all(self): + self.check_thread() + self.cache.clear() + for entry in self._pending_deferred_cache.values(): + entry.invalidate() + self._pending_deferred_cache.clear() + + +class CacheEntry: + __slots__ = ["deferred", "callbacks", "invalidated"] + + def __init__( + self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]] + ): + self.deferred = deferred + self.callbacks = set(callbacks) + self.invalidated = False + + def invalidate(self): + if not self.invalidated: + self.invalidated = True + for callback in self.callbacks: + callback() + self.callbacks.clear() diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 98b34f2223..a924140cdf 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py
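The rewrite below splits the old Cache/CacheDescriptor into DeferredCache/DeferredCacheDescriptor, adds a synchronous @lru_cache decorator, and hoists the cache-key logic into get_cache_key_builder. The key scheme itself is unchanged: a single-parameter function keys on the bare argument, anything else keys on a tuple built from args, kwargs and defaults. A small sketch (the parameter names are made up for the example):

    from synapse.util.caches.descriptors import get_cache_key_builder

    one_arg = get_cache_key_builder(["user_id"], {})
    assert one_arg(("@alice:test",), {}) == "@alice:test"

    two_args = get_cache_key_builder(["room_id", "limit"], {"limit": 10})
    assert two_args(("!room:test",), {}) == ("!room:test", 10)
    assert two_args((), {"room_id": "!room:test", "limit": 5}) == ("!room:test", 5)
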
@@ -13,25 +13,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import enum import functools import inspect import logging -import threading -from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast +from typing import ( + Any, + Callable, + Generic, + Iterable, + Mapping, + Optional, + Sequence, + Tuple, + TypeVar, + Union, + cast, +) from weakref import WeakValueDictionary -from prometheus_client import Gauge - from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.util import unwrapFirstError -from synapse.util.async_helpers import ObservableDeferred +from synapse.util.caches.deferred_cache import DeferredCache from synapse.util.caches.lrucache import LruCache -from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry - -from . import register_cache logger = logging.getLogger(__name__) @@ -55,241 +61,8 @@ class _CachedFunction(Generic[F]): __call__ = None # type: F -cache_pending_metric = Gauge( - "synapse_util_caches_cache_pending", - "Number of lookups currently pending for this cache", - ["name"], -) - -_CacheSentinel = object() - - -class CacheEntry: - __slots__ = ["deferred", "callbacks", "invalidated"] - - def __init__(self, deferred, callbacks): - self.deferred = deferred - self.callbacks = set(callbacks) - self.invalidated = False - - def invalidate(self): - if not self.invalidated: - self.invalidated = True - for callback in self.callbacks: - callback() - self.callbacks.clear() - - -class Cache: - __slots__ = ( - "cache", - "name", - "keylen", - "thread", - "metrics", - "_pending_deferred_cache", - ) - - def __init__( - self, - name: str, - max_entries: int = 1000, - keylen: int = 1, - tree: bool = False, - iterable: bool = False, - apply_cache_factor_from_config: bool = True, - ): - """ - Args: - name: The name of the cache - max_entries: Maximum amount of entries that the cache will hold - keylen: The length of the tuple used as the cache key - tree: Use a TreeCache instead of a dict as the underlying cache type - iterable: If True, count each item in the cached object as an entry, - rather than each cached object - apply_cache_factor_from_config: Whether cache factors specified in the - config file affect `max_entries` - - Returns: - Cache - """ - cache_type = TreeCache if tree else dict - self._pending_deferred_cache = cache_type() - - self.cache = LruCache( - max_size=max_entries, - keylen=keylen, - cache_type=cache_type, - size_callback=(lambda d: len(d)) if iterable else None, - evicted_callback=self._on_evicted, - apply_cache_factor_from_config=apply_cache_factor_from_config, - ) - - self.name = name - self.keylen = keylen - self.thread = None # type: Optional[threading.Thread] - self.metrics = register_cache( - "cache", - name, - self.cache, - collect_callback=self._metrics_collection_callback, - ) - - @property - def max_entries(self): - return self.cache.max_size - - def _on_evicted(self, evicted_count): - self.metrics.inc_evictions(evicted_count) - - def _metrics_collection_callback(self): - cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache)) - - def check_thread(self): - expected_thread = self.thread - if expected_thread is None: - self.thread = threading.current_thread() - else: - if expected_thread is not threading.current_thread(): - raise ValueError( - "Cache objects can only be accessed from 
the main thread" - ) - - def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True): - """Looks the key up in the caches. - - Args: - key(tuple) - default: What is returned if key is not in the caches. If not - specified then function throws KeyError instead - callback(fn): Gets called when the entry in the cache is invalidated - update_metrics (bool): whether to update the cache hit rate metrics - - Returns: - Either an ObservableDeferred or the raw result - """ - callbacks = [callback] if callback else [] - val = self._pending_deferred_cache.get(key, _CacheSentinel) - if val is not _CacheSentinel: - val.callbacks.update(callbacks) - if update_metrics: - self.metrics.inc_hits() - return val.deferred - - val = self.cache.get(key, _CacheSentinel, callbacks=callbacks) - if val is not _CacheSentinel: - self.metrics.inc_hits() - return val - - if update_metrics: - self.metrics.inc_misses() - - if default is _CacheSentinel: - raise KeyError() - else: - return default - - def set(self, key, value, callback=None): - if not isinstance(value, defer.Deferred): - raise TypeError("not a Deferred") - - callbacks = [callback] if callback else [] - self.check_thread() - observable = ObservableDeferred(value, consumeErrors=True) - observer = observable.observe() - entry = CacheEntry(deferred=observable, callbacks=callbacks) - - existing_entry = self._pending_deferred_cache.pop(key, None) - if existing_entry: - existing_entry.invalidate() - - self._pending_deferred_cache[key] = entry - - def compare_and_pop(): - """Check if our entry is still the one in _pending_deferred_cache, and - if so, pop it. - - Returns true if the entries matched. - """ - existing_entry = self._pending_deferred_cache.pop(key, None) - if existing_entry is entry: - return True - - # oops, the _pending_deferred_cache has been updated since - # we started our query, so we are out of date. - # - # Better put back whatever we took out. (We do it this way - # round, rather than peeking into the _pending_deferred_cache - # and then removing on a match, to make the common case faster) - if existing_entry is not None: - self._pending_deferred_cache[key] = existing_entry - - return False - - def cb(result): - if compare_and_pop(): - self.cache.set(key, result, entry.callbacks) - else: - # we're not going to put this entry into the cache, so need - # to make sure that the invalidation callbacks are called. - # That was probably done when _pending_deferred_cache was - # updated, but it's possible that `set` was called without - # `invalidate` being previously called, in which case it may - # not have been. Either way, let's double-check now. - entry.invalidate() - - def eb(_fail): - compare_and_pop() - entry.invalidate() - - # once the deferred completes, we can move the entry from the - # _pending_deferred_cache to the real cache. - # - observer.addCallbacks(cb, eb) - return observable - - def prefill(self, key, value, callback=None): - callbacks = [callback] if callback else [] - self.cache.set(key, value, callbacks=callbacks) - - def invalidate(self, key): - self.check_thread() - self.cache.pop(key, None) - - # if we have a pending lookup for this key, remove it from the - # _pending_deferred_cache, which will (a) stop it being returned - # for future queries and (b) stop it being persisted as a proper entry - # in self.cache. - entry = self._pending_deferred_cache.pop(key, None) - - # run the invalidation callbacks now, rather than waiting for the - # deferred to resolve. 
- if entry: - entry.invalidate() - - def invalidate_many(self, key): - self.check_thread() - if not isinstance(key, tuple): - raise TypeError("The cache key must be a tuple not %r" % (type(key),)) - self.cache.del_multi(key) - - # if we have a pending lookup for this key, remove it from the - # _pending_deferred_cache, as above - entry_dict = self._pending_deferred_cache.pop(key, None) - if entry_dict is not None: - for entry in iterate_tree_cache_entry(entry_dict): - entry.invalidate() - - def invalidate_all(self): - self.check_thread() - self.cache.clear() - for entry in self._pending_deferred_cache.values(): - entry.invalidate() - self._pending_deferred_cache.clear() - - class _CacheDescriptorBase: - def __init__(self, orig: _CachedFunction, num_args, cache_context=False): + def __init__(self, orig: Callable[..., Any], num_args, cache_context=False): self.orig = orig arg_spec = inspect.getfullargspec(orig) @@ -338,8 +111,107 @@ class _CacheDescriptorBase: self.add_cache_context = cache_context + self.cache_key_builder = get_cache_key_builder( + self.arg_names, self.arg_defaults + ) + + +class _LruCachedFunction(Generic[F]): + cache = None # type: LruCache[CacheKey, Any] + __call__ = None # type: F + + +def lru_cache( + max_entries: int = 1000, cache_context: bool = False, +) -> Callable[[F], _LruCachedFunction[F]]: + """A method decorator that applies a memoizing cache around the function. + + This is more-or-less a drop-in equivalent to functools.lru_cache, although note + that the signature is slightly different. + + The main differences with functools.lru_cache are: + (a) the size of the cache can be controlled via the cache_factor mechanism + (b) the wrapped function can request a "cache_context" which provides a + callback mechanism to indicate that the result is no longer valid + (c) prometheus metrics are exposed automatically. + + The function should take zero or more arguments, which are used as the key for the + cache. Single-argument functions use that argument as the cache key; otherwise the + arguments are built into a tuple. + + Cached functions can be "chained" (i.e. a cached function can call other cached + functions and get appropriately invalidated when they called caches are + invalidated) by adding a special "cache_context" argument to the function + and passing that as a kwarg to all caches called. For example: + + @lru_cache(cache_context=True) + def foo(self, key, cache_context): + r1 = self.bar1(key, on_invalidate=cache_context.invalidate) + r2 = self.bar2(key, on_invalidate=cache_context.invalidate) + return r1 + r2 + + The wrapped function also has a 'cache' property which offers direct access to the + underlying LruCache. 
+ """ + + def func(orig: F) -> _LruCachedFunction[F]: + desc = LruCacheDescriptor( + orig, max_entries=max_entries, cache_context=cache_context, + ) + return cast(_LruCachedFunction[F], desc) + + return func + + +class LruCacheDescriptor(_CacheDescriptorBase): + """Helper for @lru_cache""" -class CacheDescriptor(_CacheDescriptorBase): + class _Sentinel(enum.Enum): + sentinel = object() + + def __init__( + self, orig, max_entries: int = 1000, cache_context: bool = False, + ): + super().__init__(orig, num_args=None, cache_context=cache_context) + self.max_entries = max_entries + + def __get__(self, obj, owner): + cache = LruCache( + cache_name=self.orig.__name__, max_size=self.max_entries, + ) # type: LruCache[CacheKey, Any] + + get_cache_key = self.cache_key_builder + sentinel = LruCacheDescriptor._Sentinel.sentinel + + @functools.wraps(self.orig) + def _wrapped(*args, **kwargs): + invalidate_callback = kwargs.pop("on_invalidate", None) + callbacks = (invalidate_callback,) if invalidate_callback else () + + cache_key = get_cache_key(args, kwargs) + + ret = cache.get(cache_key, default=sentinel, callbacks=callbacks) + if ret != sentinel: + return ret + + # Add our own `cache_context` to argument list if the wrapped function + # has asked for one + if self.add_cache_context: + kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key) + + ret2 = self.orig(obj, *args, **kwargs) + cache.set(cache_key, ret2, callbacks=callbacks) + + return ret2 + + wrapped = cast(_CachedFunction, _wrapped) + wrapped.cache = cache + obj.__dict__[self.orig.__name__] = wrapped + + return wrapped + + +class DeferredCacheDescriptor(_CacheDescriptorBase): """ A method decorator that applies a memoizing cache around the function. This caches deferreds, rather than the results themselves. Deferreds that @@ -382,7 +254,6 @@ class CacheDescriptor(_CacheDescriptorBase): cache_context=False, iterable=False, ): - super().__init__(orig, num_args=num_args, cache_context=cache_context) self.max_entries = max_entries @@ -390,49 +261,15 @@ class CacheDescriptor(_CacheDescriptorBase): self.iterable = iterable def __get__(self, obj, owner): - cache = Cache( + cache = DeferredCache( name=self.orig.__name__, max_entries=self.max_entries, keylen=self.num_args, tree=self.tree, iterable=self.iterable, - ) - - def get_cache_key_gen(args, kwargs): - """Given some args/kwargs return a generator that resolves into - the cache_key. - - We loop through each arg name, looking up if its in the `kwargs`, - otherwise using the next argument in `args`. If there are no more - args then we try looking the arg name up in the defaults - """ - pos = 0 - for nm in self.arg_names: - if nm in kwargs: - yield kwargs[nm] - elif pos < len(args): - yield args[pos] - pos += 1 - else: - yield self.arg_defaults[nm] - - # By default our cache key is a tuple, but if there is only one item - # then don't bother wrapping in a tuple. This is to save memory. 
- if self.num_args == 1: - nm = self.arg_names[0] - - def get_cache_key(args, kwargs): - if nm in kwargs: - return kwargs[nm] - elif len(args): - return args[0] - else: - return self.arg_defaults[nm] - - else: + ) # type: DeferredCache[CacheKey, Any] - def get_cache_key(args, kwargs): - return tuple(get_cache_key_gen(args, kwargs)) + get_cache_key = self.cache_key_builder @functools.wraps(self.orig) def _wrapped(*args, **kwargs): @@ -442,32 +279,20 @@ class CacheDescriptor(_CacheDescriptorBase): cache_key = get_cache_key(args, kwargs) - # Add our own `cache_context` to argument list if the wrapped function - # has asked for one - if self.add_cache_context: - kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key) - try: - cached_result_d = cache.get(cache_key, callback=invalidate_callback) - - if isinstance(cached_result_d, ObservableDeferred): - observer = cached_result_d.observe() - else: - observer = defer.succeed(cached_result_d) - + ret = cache.get(cache_key, callback=invalidate_callback) except KeyError: - ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs) + # Add our own `cache_context` to argument list if the wrapped function + # has asked for one + if self.add_cache_context: + kwargs["cache_context"] = _CacheContext.get_instance( + cache, cache_key + ) - def onErr(f): - cache.invalidate(cache_key) - return f - - ret.addErrback(onErr) - - result_d = cache.set(cache_key, ret, callback=invalidate_callback) - observer = result_d.observe() + ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs) + ret = cache.set(cache_key, ret, callback=invalidate_callback) - return make_deferred_yieldable(observer) + return make_deferred_yieldable(ret) wrapped = cast(_CachedFunction, _wrapped) @@ -476,7 +301,6 @@ class CacheDescriptor(_CacheDescriptorBase): wrapped.prefill = lambda key, val: cache.prefill(key[0], val) else: wrapped.invalidate = cache.invalidate - wrapped.invalidate_all = cache.invalidate_all wrapped.invalidate_many = cache.invalidate_many wrapped.prefill = cache.prefill @@ -489,7 +313,7 @@ class CacheDescriptor(_CacheDescriptorBase): return wrapped -class CacheListDescriptor(_CacheDescriptorBase): +class DeferredCacheListDescriptor(_CacheDescriptorBase): """Wraps an existing cache to support bulk fetching of keys. Given a list of keys it looks in the cache to find any hits, then passes @@ -526,7 +350,7 @@ class CacheListDescriptor(_CacheDescriptorBase): def __get__(self, obj, objtype=None): cached_method = getattr(obj, self.cached_method_name) - cache = cached_method.cache + cache = cached_method.cache # type: DeferredCache[CacheKey, Any] num_args = cached_method.num_args @functools.wraps(self.orig) @@ -566,14 +390,11 @@ class CacheListDescriptor(_CacheDescriptorBase): for arg in list_args: try: res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback) - if not isinstance(res, ObservableDeferred): - results[arg] = res - elif not res.has_succeeded(): - res = res.observe() + if not res.called: res.addCallback(update_results_dict, arg) cached_defers.append(res) else: - results[arg] = res.get_result() + results[arg] = res.result except KeyError: missing.add(arg) @@ -638,11 +459,13 @@ class _CacheContext: on a lower level. 
""" + Cache = Union[DeferredCache, LruCache] + _cache_context_objects = ( WeakValueDictionary() - ) # type: WeakValueDictionary[Tuple[Cache, CacheKey], _CacheContext] + ) # type: WeakValueDictionary[Tuple[_CacheContext.Cache, CacheKey], _CacheContext] - def __init__(self, cache, cache_key): # type: (Cache, CacheKey) -> None + def __init__(self, cache: "_CacheContext.Cache", cache_key: CacheKey) -> None: self._cache = cache self._cache_key = cache_key @@ -651,7 +474,9 @@ class _CacheContext: self._cache.invalidate(self._cache_key) @classmethod - def get_instance(cls, cache, cache_key): # type: (Cache, CacheKey) -> _CacheContext + def get_instance( + cls, cache: "_CacheContext.Cache", cache_key: CacheKey + ) -> "_CacheContext": """Returns an instance constructed with the given arguments. A new instance is only created if none already exists. @@ -672,7 +497,7 @@ def cached( cache_context: bool = False, iterable: bool = False, ) -> Callable[[F], _CachedFunction[F]]: - func = lambda orig: CacheDescriptor( + func = lambda orig: DeferredCacheDescriptor( orig, max_entries=max_entries, num_args=num_args, @@ -714,7 +539,7 @@ def cachedList( def batch_do_something(self, first_arg, second_args): ... """ - func = lambda orig: CacheListDescriptor( + func = lambda orig: DeferredCacheListDescriptor( orig, cached_method_name=cached_method_name, list_name=list_name, @@ -722,3 +547,65 @@ def cachedList( ) return cast(Callable[[F], _CachedFunction[F]], func) + + +def get_cache_key_builder( + param_names: Sequence[str], param_defaults: Mapping[str, Any] +) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]: + """Construct a function which will build cache keys suitable for a cached function + + Args: + param_names: list of formal parameter names for the cached function + param_defaults: a mapping from parameter name to default value for that param + + Returns: + A function which will take an (args, kwargs) pair and return a cache key + """ + + # By default our cache key is a tuple, but if there is only one item + # then don't bother wrapping in a tuple. This is to save memory. + + if len(param_names) == 1: + nm = param_names[0] + + def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey: + if nm in kwargs: + return kwargs[nm] + elif len(args): + return args[0] + else: + return param_defaults[nm] + + else: + + def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey: + return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs)) + + return get_cache_key + + +def _get_cache_key_gen( + param_names: Iterable[str], + param_defaults: Mapping[str, Any], + args: Sequence[Any], + kwargs: Mapping[str, Any], +) -> Iterable[Any]: + """Given some args/kwargs return a generator that resolves into + the cache_key. + + This is essentially the same operation as `inspect.getcallargs`, but optimised so + that we don't need to inspect the target function for each call. + """ + + # We loop through each arg name, looking up if its in the `kwargs`, + # otherwise using the next argument in `args`. If there are no more + # args then we try looking the arg name up in the defaults. + pos = 0 + for nm in param_names: + if nm in kwargs: + yield kwargs[nm] + elif pos < len(args): + yield args[pos] + pos += 1 + else: + yield param_defaults[nm] diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 8592b93689..588d2d49f2 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py
@@ -12,15 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import enum import logging import threading from collections import namedtuple +from typing import Any from synapse.util.caches.lrucache import LruCache -from . import register_cache - logger = logging.getLogger(__name__) @@ -40,24 +39,25 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va return len(self.value) +class _Sentinel(enum.Enum): + # defining a sentinel in this way allows mypy to correctly handle the + # type of a dictionary lookup. + sentinel = object() + + class DictionaryCache: """Caches key -> dictionary lookups, supporting caching partial dicts, i.e. fetching a subset of dictionary keys for a particular key. """ def __init__(self, name, max_entries=1000): - self.cache = LruCache(max_size=max_entries, size_callback=len) + self.cache = LruCache( + max_size=max_entries, cache_name=name, size_callback=len + ) # type: LruCache[Any, DictionaryEntry] self.name = name self.sequence = 0 self.thread = None - # caches_by_name[name] = self.cache - - class Sentinel: - __slots__ = [] - - self.sentinel = Sentinel() - self.metrics = register_cache("dictionary", name, self.cache) def check_thread(self): expected_thread = self.thread @@ -80,10 +80,8 @@ class DictionaryCache: Returns: DictionaryEntry """ - entry = self.cache.get(key, self.sentinel) - if entry is not self.sentinel: - self.metrics.inc_hits() - + entry = self.cache.get(key, _Sentinel.sentinel) + if entry is not _Sentinel.sentinel: if dict_keys is None: return DictionaryEntry( entry.full, entry.known_absent, dict(entry.value) @@ -95,7 +93,6 @@ class DictionaryCache: {k: entry.value[k] for k in dict_keys if k in entry.value}, ) - self.metrics.inc_misses() return DictionaryEntry(False, set(), {}) def invalidate(self, key): diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 4bc1a67b58..60bb6ff642 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py
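LruCache below takes over prometheus metrics registration from its wrappers (via the new cache_name argument) and grows typed get/pop overloads plus an invalidate alias for pop. A small usage sketch, with the config-driven cache factor disabled so the size is literal; "example_cache" is a placeholder name:

    from synapse.util.caches.lrucache import LruCache

    cache = LruCache(
        max_size=10, cache_name="example_cache", apply_cache_factor_from_config=False
    )  # type: LruCache[str, int]

    cache.set("a", 1)
    assert cache.get("a") == 1        # counted as a hit in the prometheus metrics
    assert cache.get("b", -1) == -1   # counted as a miss; the default is returned
    cache.invalidate("a")             # alias for pop(), used by the cache invalidation stream
    assert cache.get("a", None) is None
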
@@ -15,11 +15,35 @@ import threading from functools import wraps -from typing import Callable, Optional, Type, Union +from typing import ( + Any, + Callable, + Generic, + Iterable, + Optional, + Type, + TypeVar, + Union, + cast, + overload, +) + +from typing_extensions import Literal from synapse.config import cache as cache_config +from synapse.util.caches import CacheMetric, register_cache from synapse.util.caches.treecache import TreeCache +# Function type: the type used for invalidation callbacks +FT = TypeVar("FT", bound=Callable[..., Any]) + +# Key and Value type for the cache +KT = TypeVar("KT") +VT = TypeVar("VT") + +# a general type var, distinct from either KT or VT +T = TypeVar("T") + def enumerate_leaves(node, depth): if depth == 0: @@ -41,30 +65,33 @@ class _Node: self.callbacks = callbacks -class LruCache: +class LruCache(Generic[KT, VT]): """ - Least-recently-used cache. + Least-recently-used cache, supporting prometheus metrics and invalidation callbacks. + Supports del_multi only if cache_type=TreeCache If cache_type=TreeCache, all keys must be tuples. - - Can also set callbacks on objects when getting/setting which are fired - when that key gets invalidated/evicted. """ def __init__( self, max_size: int, + cache_name: Optional[str] = None, keylen: int = 1, cache_type: Type[Union[dict, TreeCache]] = dict, size_callback: Optional[Callable] = None, - evicted_callback: Optional[Callable] = None, + metrics_collection_callback: Optional[Callable[[], None]] = None, apply_cache_factor_from_config: bool = True, ): """ Args: max_size: The maximum amount of entries the cache can hold - keylen: The length of the tuple used as the cache key + cache_name: The name of this cache, for the prometheus metrics. If unset, + no metrics will be reported on this cache. + + keylen: The length of the tuple used as the cache key. Ignored unless + cache_type is `TreeCache`. cache_type (type): type of underlying cache to be used. Typically one of dict @@ -72,9 +99,13 @@ class LruCache: size_callback (func(V) -> int | None): - evicted_callback (func(int)|None): - if not None, called on eviction with the size of the evicted - entry + metrics_collection_callback: + metrics collection callback. This is called early in the metrics + collection process, before any of the metrics registered with the + prometheus Registry are collected, so can be used to update any dynamic + metrics. + + Ignored if cache_name is None. apply_cache_factor_from_config (bool): If true, `max_size` will be multiplied by a cache factor derived from the homeserver config @@ -93,6 +124,23 @@ class LruCache: else: self.max_size = int(max_size) + # register_cache might call our "set_cache_factor" callback; there's nothing to + # do yet when we get resized. 
+ self._on_resize = None # type: Optional[Callable[[],None]] + + if cache_name is not None: + metrics = register_cache( + "lru_cache", + cache_name, + self, + collect_callback=metrics_collection_callback, + ) # type: Optional[CacheMetric] + else: + metrics = None + + # this is exposed for access from outside this class + self.metrics = metrics + list_root = _Node(None, None, None, None) list_root.next_node = list_root list_root.prev_node = list_root @@ -104,16 +152,16 @@ class LruCache: todelete = list_root.prev_node evicted_len = delete_node(todelete) cache.pop(todelete.key, None) - if evicted_callback: - evicted_callback(evicted_len) + if metrics: + metrics.inc_evictions(evicted_len) - def synchronized(f): + def synchronized(f: FT) -> FT: @wraps(f) def inner(*args, **kwargs): with lock: return f(*args, **kwargs) - return inner + return cast(FT, inner) cached_cache_len = [0] if size_callback is not None: @@ -167,18 +215,45 @@ class LruCache: node.callbacks.clear() return deleted_len + @overload + def cache_get( + key: KT, + default: Literal[None] = None, + callbacks: Iterable[Callable[[], None]] = ..., + update_metrics: bool = ..., + ) -> Optional[VT]: + ... + + @overload + def cache_get( + key: KT, + default: T, + callbacks: Iterable[Callable[[], None]] = ..., + update_metrics: bool = ..., + ) -> Union[T, VT]: + ... + @synchronized - def cache_get(key, default=None, callbacks=[]): + def cache_get( + key: KT, + default: Optional[T] = None, + callbacks: Iterable[Callable[[], None]] = [], + update_metrics: bool = True, + ): node = cache.get(key, None) if node is not None: move_node_to_front(node) node.callbacks.update(callbacks) + if update_metrics and metrics: + metrics.inc_hits() return node.value else: + if update_metrics and metrics: + metrics.inc_misses() return default @synchronized - def cache_set(key, value, callbacks=[]): + def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []): node = cache.get(key, None) if node is not None: # We sometimes store large objects, e.g. dicts, which cause @@ -207,7 +282,7 @@ class LruCache: evict() @synchronized - def cache_set_default(key, value): + def cache_set_default(key: KT, value: VT) -> VT: node = cache.get(key, None) if node is not None: return node.value @@ -216,8 +291,16 @@ class LruCache: evict() return value + @overload + def cache_pop(key: KT, default: Literal[None] = None) -> Optional[VT]: + ... + + @overload + def cache_pop(key: KT, default: T) -> Union[T, VT]: + ... + @synchronized - def cache_pop(key, default=None): + def cache_pop(key: KT, default: Optional[T] = None): node = cache.get(key, None) if node: delete_node(node) @@ -227,18 +310,18 @@ class LruCache: return default @synchronized - def cache_del_multi(key): + def cache_del_multi(key: KT) -> None: """ This will only work if constructed with cache_type=TreeCache """ popped = cache.pop(key) if popped is None: return - for leaf in enumerate_leaves(popped, keylen - len(key)): + for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))): delete_node(leaf) @synchronized - def cache_clear(): + def cache_clear() -> None: list_root.next_node = list_root list_root.prev_node = list_root for node in cache.values(): @@ -249,15 +332,21 @@ class LruCache: cached_cache_len[0] = 0 @synchronized - def cache_contains(key): + def cache_contains(key: KT) -> bool: return key in cache self.sentinel = object() + + # make sure that we clear out any excess entries after we get resized. 
self._on_resize = evict + self.get = cache_get self.set = cache_set self.setdefault = cache_set_default self.pop = cache_pop + # `invalidate` is exposed for consistency with DeferredCache, so that it can be + # invalidated by the cache invalidation replication stream. + self.invalidate = cache_pop if cache_type is TreeCache: self.del_multi = cache_del_multi self.len = synchronized(cache_len) @@ -301,6 +390,7 @@ class LruCache: new_size = int(self._original_max_size * factor) if new_size != self.max_size: self.max_size = new_size - self._on_resize() + if self._on_resize: + self._on_resize() return True return False diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index df1a721add..32228f42ee 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py
@@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar from twisted.internet import defer @@ -20,10 +21,15 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches import register_cache +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) +T = TypeVar("T") + -class ResponseCache: +class ResponseCache(Generic[T]): """ This caches a deferred response. Until the deferred completes it will be returned from the cache. This means that if the client retries the request @@ -31,8 +37,9 @@ class ResponseCache: used rather than trying to compute a new response. """ - def __init__(self, hs, name, timeout_ms=0): - self.pending_result_cache = {} # Requests that haven't finished yet. + def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0): + # Requests that haven't finished yet. + self.pending_result_cache = {} # type: Dict[T, ObservableDeferred] self.clock = hs.get_clock() self.timeout_sec = timeout_ms / 1000.0 @@ -40,13 +47,13 @@ class ResponseCache: self._name = name self._metrics = register_cache("response_cache", name, self, resizable=False) - def size(self): + def size(self) -> int: return len(self.pending_result_cache) - def __len__(self): + def __len__(self) -> int: return self.size() - def get(self, key): + def get(self, key: T) -> Optional[defer.Deferred]: """Look up the given key. Can return either a new Deferred (which also doesn't follow the synapse @@ -58,12 +65,11 @@ class ResponseCache: from an absent cache entry. Args: - key (hashable): + key: key to get/set in the cache Returns: - twisted.internet.defer.Deferred|None|E: None if there is no entry - for this key; otherwise either a deferred result or the result - itself. + None if there is no entry for this key; otherwise a deferred which + resolves to the result. """ result = self.pending_result_cache.get(key) if result is not None: @@ -73,7 +79,7 @@ class ResponseCache: self._metrics.inc_misses() return None - def set(self, key, deferred): + def set(self, key: T, deferred: defer.Deferred) -> defer.Deferred: """Set the entry for the given key to the given deferred. *deferred* should run its callbacks in the sentinel logcontext (ie, @@ -85,12 +91,11 @@ class ResponseCache: result. You will probably want to make_deferred_yieldable the result. Args: - key (hashable): - deferred (twisted.internet.defer.Deferred[T): + key: key to get/set in the cache + deferred: The deferred which resolves to the result. Returns: - twisted.internet.defer.Deferred[T]|T: a new deferred, or the actual - result. + A new deferred which resolves to the actual result. 
""" result = ObservableDeferred(deferred, consumeErrors=True) self.pending_result_cache[key] = result @@ -107,7 +112,9 @@ class ResponseCache: result.addBoth(remove) return result.observe() - def wrap(self, key, callback, *args, **kwargs): + def wrap( + self, key: T, callback: "Callable[..., Any]", *args: Any, **kwargs: Any + ) -> defer.Deferred: """Wrap together a *get* and *set* call, taking care of logcontexts First looks up the key in the cache, and if it is present makes it @@ -118,21 +125,20 @@ class ResponseCache: Example usage: - @defer.inlineCallbacks - def handle_request(request): + async def handle_request(request): # etc return result - result = yield response_cache.wrap( + result = await response_cache.wrap( key, handle_request, request, ) Args: - key (hashable): key to get/set in the cache + key: key to get/set in the cache - callback (callable): function to call if the key is not found in + callback: function to call if the key is not found in the cache *args: positional parameters to pass to the callback, if it is used @@ -140,7 +146,7 @@ class ResponseCache: **kwargs: named parameters to pass to the callback, if it is used Returns: - twisted.internet.defer.Deferred: yieldable result + Deferred which resolves to the result """ result = self.get(key) if not result: diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index 3e180cafd3..6ce2a3d12b 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py
@@ -34,7 +34,7 @@ class TTLCache: self._data = {} # the _CacheEntries, sorted by expiry time - self._expiry_list = SortedList() + self._expiry_list = SortedList() # type: SortedList[_CacheEntry] self._timer = timer diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index f73e95393c..a6ee9edaec 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py
@@ -12,13 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect import logging from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util.async_helpers import maybe_awaitable logger = logging.getLogger(__name__) @@ -105,10 +105,7 @@ class Signal: async def do(observer): try: - result = observer(*args, **kwargs) - if inspect.isawaitable(result): - result = await result - return result + return await maybe_awaitable(observer(*args, **kwargs)) except Exception as e: logger.warning( "%s signal observer %s failed: %r", self.name, observer, e, diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
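The distributor change above leans on `maybe_awaitable`, which lets a single `await` handle both plain and coroutine observers; a hedged illustration (the observer names are made up):

    def sync_observer(value):
        return value + 1

    async def async_observer(value):
        return value + 1

    # Either kind of result can be awaited through the same code path:
    #   await maybe_awaitable(sync_observer(1))    # -> 2
    #   await maybe_awaitable(async_observer(1))   # -> 2
    # replacing the removed inspect.isawaitable() branch.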
index bf094c9386..5f7a6dd1d3 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py
@@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json - from frozendict import frozendict @@ -49,23 +47,3 @@ def unfreeze(o): pass return o - - -def _handle_frozendict(obj): - """Helper for EventEncoder. Makes frozendicts serializable by returning - the underlying dict - """ - if type(obj) is frozendict: - # fishing the protected dict out of the object is a bit nasty, - # but we don't really want the overhead of copying the dict. - return obj._dict - raise TypeError( - "Object of type %s is not JSON serializable" % obj.__class__.__name__ - ) - - -# A JSONEncoder which is capable of encoding frozendicts without barfing. -# Additionally reduce the whitespace produced by JSON encoding. -frozendict_json_encoder = json.JSONEncoder( - allow_nan=False, separators=(",", ":"), default=_handle_frozendict, -) diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index bb62db4637..1ee61851e4 100644 --- a/synapse/util/module_loader.py +++ b/synapse/util/module_loader.py
@@ -15,28 +15,56 @@ import importlib import importlib.util +import itertools +from typing import Any, Iterable, Tuple, Type + +import jsonschema from synapse.config._base import ConfigError +from synapse.config._util import json_error_to_config_error -def load_module(provider): +def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]: """ Loads a synapse module with its config - Take a dict with keys 'module' (the module name) and 'config' - (the config dict). + + Args: + provider: a dict with keys 'module' (the module name) and 'config' + (the config dict). + config_path: the path within the config file. This will be used as a basis + for any error message. Returns Tuple of (provider class, parsed config object) """ + + modulename = provider.get("module") + if not isinstance(modulename, str): + raise ConfigError( + "expected a string", path=itertools.chain(config_path, ("module",)) + ) + # We need to import the module, and then pick the class out of # that, so we split based on the last dot. - module, clz = provider["module"].rsplit(".", 1) + module, clz = modulename.rsplit(".", 1) module = importlib.import_module(module) provider_class = getattr(module, clz) + module_config = provider.get("config") try: - provider_config = provider_class.parse_config(provider.get("config")) + provider_config = provider_class.parse_config(module_config) + except jsonschema.ValidationError as e: + raise json_error_to_config_error(e, itertools.chain(config_path, ("config",))) + except ConfigError as e: + raise _wrap_config_error( + "Failed to parse config for module %r" % (modulename,), + prefix=itertools.chain(config_path, ("config",)), + e=e, + ) except Exception as e: - raise ConfigError("Failed to parse config for %r: %r" % (provider["module"], e)) + raise ConfigError( + "Failed to parse config for module %r" % (modulename,), + path=itertools.chain(config_path, ("config",)), + ) from e return provider_class, provider_config @@ -56,3 +84,27 @@ def load_python_module(location: str): mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) # type: ignore return mod + + +def _wrap_config_error( + msg: str, prefix: Iterable[str], e: ConfigError +) -> "ConfigError": + """Wrap a relative ConfigError with a new path + + This is useful when we have a ConfigError with a relative path due to a problem + parsing part of the config, and we now need to set it in context. + """ + path = prefix + if e.path: + path = itertools.chain(prefix, e.path) + + e1 = ConfigError(msg, path) + + # ideally we would set the 'cause' of the new exception to the original exception; + # however now that we have merged the path into our own, the stringification of + # e will be incorrect, so instead we create a new exception with just the "msg" + # part. + + e1.__cause__ = Exception(e.msg) + e1.__cause__.__cause__ = e.__cause__ + return e1 diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
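A hedged sketch of the new load_module calling convention shown above (the module path and config values are purely illustrative):

    provider = {
        "module": "my_project.spam.ExampleSpamChecker",   # hypothetical module
        "config": {"block_all": False},
    }
    # config_path is only used to build readable error messages, for example
    # pointing at spam_checker.<item 0>.config if parsing fails.
    provider_class, parsed_config = load_module(provider, ("spam_checker", "<item 0>"))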
index a5cc9d0551..4ab379e429 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py
@@ -110,7 +110,7 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k failure_ts, retry_interval, backoff_on_failure=backoff_on_failure, - **kwargs + **kwargs, ) diff --git a/synapse/visibility.py b/synapse/visibility.py
index e3da7744d2..ec50e7e977 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py
@@ -12,11 +12,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging import operator -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import ( + AccountDataTypes, + EventTypes, + HistoryVisibility, + Membership, +) from synapse.events.utils import prune_event from synapse.storage import Storage from synapse.storage.state import StateFilter @@ -25,7 +29,12 @@ from synapse.types import get_domain_from_id logger = logging.getLogger(__name__) -VISIBILITY_PRIORITY = ("world_readable", "shared", "invited", "joined") +VISIBILITY_PRIORITY = ( + HistoryVisibility.WORLD_READABLE, + HistoryVisibility.SHARED, + HistoryVisibility.INVITED, + HistoryVisibility.JOINED, +) MEMBERSHIP_PRIORITY = ( @@ -77,15 +86,14 @@ async def filter_events_for_client( ) ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user( - "m.ignored_user_list", user_id + AccountDataTypes.IGNORED_USER_LIST, user_id ) - # FIXME: This will explode if people upload something incorrect. - ignore_list = frozenset( - ignore_dict_content.get("ignored_users", {}).keys() - if ignore_dict_content - else [] - ) + ignore_list = frozenset() + if ignore_dict_content: + ignored_users_dict = ignore_dict_content.get("ignored_users", {}) + if isinstance(ignored_users_dict, dict): + ignore_list = frozenset(ignored_users_dict.keys()) erased_senders = await storage.main.are_users_erased((e.sender for e in events)) @@ -117,7 +125,7 @@ async def filter_events_for_client( # see events in the room at that point in the DAG, and that shouldn't be decided # on those checks. if filter_send_to_client: - if event.type == "org.matrix.dummy_event": + if event.type == EventTypes.Dummy: return None if not event.is_state() and event.sender in ignore_list: @@ -151,12 +159,14 @@ async def filter_events_for_client( # get the room_visibility at the time of the event. visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None) if visibility_event: - visibility = visibility_event.content.get("history_visibility", "shared") + visibility = visibility_event.content.get( + "history_visibility", HistoryVisibility.SHARED + ) else: - visibility = "shared" + visibility = HistoryVisibility.SHARED if visibility not in VISIBILITY_PRIORITY: - visibility = "shared" + visibility = HistoryVisibility.SHARED # Always allow history visibility events on boundaries. This is done # by setting the effective visibility to the least restrictive @@ -166,7 +176,7 @@ async def filter_events_for_client( prev_visibility = prev_content.get("history_visibility", None) if prev_visibility not in VISIBILITY_PRIORITY: - prev_visibility = "shared" + prev_visibility = HistoryVisibility.SHARED new_priority = VISIBILITY_PRIORITY.index(visibility) old_priority = VISIBILITY_PRIORITY.index(prev_visibility) @@ -211,17 +221,17 @@ async def filter_events_for_client( # otherwise, it depends on the room visibility. - if visibility == "joined": + if visibility == HistoryVisibility.JOINED: # we weren't a member at the time of the event, so we can't # see this event. return None - elif visibility == "invited": + elif visibility == HistoryVisibility.INVITED: # user can also see the event if they were *invited* at the time # of the event. 
return event if membership == Membership.INVITE else None - elif visibility == "shared" and is_peeking: + elif visibility == HistoryVisibility.SHARED and is_peeking: # if the visibility is shared, users cannot see the event unless # they have *subequently* joined the room (or were members at the # time, of course) @@ -285,8 +295,10 @@ async def filter_events_for_server( def check_event_is_visible(event, state): history = state.get((EventTypes.RoomHistoryVisibility, ""), None) if history: - visibility = history.content.get("history_visibility", "shared") - if visibility in ["invited", "joined"]: + visibility = history.content.get( + "history_visibility", HistoryVisibility.SHARED + ) + if visibility in [HistoryVisibility.INVITED, HistoryVisibility.JOINED]: # We now loop through all state events looking for # membership states for the requesting server to determine # if the server is either in the room or has been invited @@ -306,7 +318,7 @@ async def filter_events_for_server( if memtype == Membership.JOIN: return True elif memtype == Membership.INVITE: - if visibility == "invited": + if visibility == HistoryVisibility.INVITED: return True else: # server has no users in the room: redact @@ -337,7 +349,8 @@ async def filter_events_for_server( else: event_map = await storage.main.get_events(visibility_ids) all_open = all( - e.content.get("history_visibility") in (None, "shared", "world_readable") + e.content.get("history_visibility") + in (None, HistoryVisibility.SHARED, HistoryVisibility.WORLD_READABLE) for e in event_map.values() ) diff --git a/synctl b/synctl
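For context on the VISIBILITY_PRIORITY tuple used above: a lower index means a less restrictive setting, and on a history_visibility boundary the less restrictive of the old and new values wins. A hedged illustration:

    new_priority = VISIBILITY_PRIORITY.index(HistoryVisibility.JOINED)          # 3
    old_priority = VISIBILITY_PRIORITY.index(HistoryVisibility.WORLD_READABLE)  # 0
    # the boundary event is treated with the lower-index (less restrictive)
    # visibility, here WORLD_READABLE, just as the string-literal version did.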
index 9395ebd048..cfa9cec0c4 100755 --- a/synctl +++ b/synctl
@@ -358,6 +358,13 @@ def main(): for worker in workers: env = os.environ.copy() + # Skip starting a worker if its already running + if os.path.exists(worker.pidfile) and pid_running( + int(open(worker.pidfile).read()) + ): + print(worker.app + " already running") + continue + if worker.cache_factor: os.environ["SYNAPSE_CACHE_FACTOR"] = str(worker.cache_factor) diff --git a/synmark/__init__.py b/synmark/__init__.py
index 53698bd5ab..3d4ec3e184 100644 --- a/synmark/__init__.py +++ b/synmark/__init__.py
@@ -15,55 +15,19 @@ import sys -from twisted.internet import epollreactor +try: + from twisted.internet.epollreactor import EPollReactor as Reactor +except ImportError: + from twisted.internet.pollreactor import PollReactor as Reactor from twisted.internet.main import installReactor -from synapse.config.homeserver import HomeServerConfig -from synapse.util import Clock - -from tests.utils import default_config, setup_test_homeserver - - -async def make_homeserver(reactor, config=None): - """ - Make a Homeserver suitable for running benchmarks against. - - Args: - reactor: A Twisted reactor to run under. - config: A HomeServerConfig to use, or None. - """ - cleanup_tasks = [] - clock = Clock(reactor) - - if not config: - config = default_config("test") - - config_obj = HomeServerConfig() - config_obj.parse_config_dict(config, "", "") - - hs = await setup_test_homeserver( - cleanup_tasks.append, config=config_obj, reactor=reactor, clock=clock - ) - stor = hs.get_datastore() - - # Run the database background updates. - if hasattr(stor.db_pool.updates, "do_next_background_update"): - while not await stor.db_pool.updates.has_completed_background_updates(): - await stor.db_pool.updates.do_next_background_update(1) - - def cleanup(): - for i in cleanup_tasks: - i() - - return hs, clock.sleep, cleanup - def make_reactor(): """ Instantiate and install a Twisted reactor suitable for testing (i.e. not the default global one). """ - reactor = epollreactor.EPollReactor() + reactor = Reactor() if "twisted.internet.reactor" in sys.modules: del sys.modules["twisted.internet.reactor"] diff --git a/synmark/__main__.py b/synmark/__main__.py
index 17df9ddeb7..de13c1a909 100644 --- a/synmark/__main__.py +++ b/synmark/__main__.py
@@ -12,20 +12,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import sys from argparse import REMAINDER from contextlib import redirect_stderr from io import StringIO import pyperf -from synmark import make_reactor -from synmark.suites import SUITES from twisted.internet.defer import Deferred, ensureDeferred from twisted.logger import globalLogBeginner, textFileLogObserver from twisted.python.failure import Failure +from synmark import make_reactor +from synmark.suites import SUITES + from tests.utils import setupdb diff --git a/synmark/suites/logging.py b/synmark/suites/logging.py
index d8e4c7d58f..c9d9cf761e 100644 --- a/synmark/suites/logging.py +++ b/synmark/suites/logging.py
@@ -13,20 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import warnings from io import StringIO from mock import Mock from pyperf import perf_counter -from synmark import make_homeserver from twisted.internet.defer import Deferred from twisted.internet.protocol import ServerFactory -from twisted.logger import LogBeginner, Logger, LogPublisher +from twisted.logger import LogBeginner, LogPublisher from twisted.protocols.basic import LineOnlyReceiver -from synapse.logging._structured import setup_structured_logging +from synapse.config.logger import _setup_stdlib_logging +from synapse.logging import RemoteHandler +from synapse.util import Clock class LineCounter(LineOnlyReceiver): @@ -62,7 +64,15 @@ async def main(reactor, loops): logger_factory.on_done = Deferred() port = reactor.listenTCP(0, logger_factory, interface="127.0.0.1") - hs, wait, cleanup = await make_homeserver(reactor) + # A fake homeserver config. + class Config: + server_name = "synmark-" + str(loops) + no_redirect_stdio = True + + hs_config = Config() + + # To be able to sleep. + clock = Clock(reactor) errors = StringIO() publisher = LogPublisher() @@ -72,47 +82,49 @@ async def main(reactor, loops): ) log_config = { - "loggers": {"synapse": {"level": "DEBUG"}}, - "drains": { + "version": 1, + "loggers": {"synapse": {"level": "DEBUG", "handlers": ["tersejson"]}}, + "formatters": {"tersejson": {"class": "synapse.logging.TerseJsonFormatter"}}, + "handlers": { "tersejson": { - "type": "network_json_terse", + "class": "synapse.logging.RemoteHandler", "host": "127.0.0.1", "port": port.getHost().port, "maximum_buffer": 100, + "_reactor": reactor, } }, } - logger = Logger(namespace="synapse.logging.test_terse_json", observer=publisher) - logging_system = setup_structured_logging( - hs, hs.config, log_config, logBeginner=beginner, redirect_stdlib_logging=False + logger = logging.getLogger("synapse.logging.test_terse_json") + _setup_stdlib_logging( + hs_config, log_config, logBeginner=beginner, ) # Wait for it to connect... - await logging_system._observers[0]._service.whenConnected() + for handler in logging.getLogger("synapse").handlers: + if isinstance(handler, RemoteHandler): + break + else: + raise RuntimeError("Improperly configured: no RemoteHandler found.") + + await handler._service.whenConnected() start = perf_counter() # Send a bunch of useful messages for i in range(0, loops): - logger.info("test message %s" % (i,)) - - if ( - len(logging_system._observers[0]._buffer) - == logging_system._observers[0].maximum_buffer - ): - while ( - len(logging_system._observers[0]._buffer) - > logging_system._observers[0].maximum_buffer / 2 - ): - await wait(0.01) + logger.info("test message %s", i) + + if len(handler._buffer) == handler.maximum_buffer: + while len(handler._buffer) > handler.maximum_buffer / 2: + await clock.sleep(0.01) await logger_factory.on_done end = perf_counter() - start - logging_system.stop() + handler.close() port.stopListening() - cleanup() return end diff --git a/sytest-blacklist b/sytest-blacklist
index b563448016..de9986357b 100644 --- a/sytest-blacklist +++ b/sytest-blacklist
@@ -34,9 +34,6 @@ New federated private chats get full presence information (SYN-115) # this requirement from the spec Inbound federation of state requires event_id as a mandatory paramater -# Blacklisted until https://github.com/matrix-org/synapse/pull/6486 lands -Can upload self-signing keys - # Blacklisted until MSC2753 is implemented Local users can peek into world_readable rooms by room ID We can't peek into rooms with shared history_visibility diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 8ab56ec94c..ee5217b074 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py
@@ -19,7 +19,6 @@ import pymacaroons from twisted.internet import defer -import synapse.handlers.auth from synapse.api.auth import Auth from synapse.api.constants import UserTypes from synapse.api.errors import ( @@ -30,26 +29,22 @@ from synapse.api.errors import ( MissingClientTokenError, ResourceLimitError, ) +from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import UserID from tests import unittest from tests.utils import mock_getRawHeaders, setup_test_homeserver -class TestHandlers: - def __init__(self, hs): - self.auth_handler = synapse.handlers.auth.AuthHandler(hs) - - class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.state_handler = Mock() self.store = Mock() - self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None) + self.hs = yield setup_test_homeserver(self.addCleanup) self.hs.get_datastore = Mock(return_value=self.store) - self.hs.handlers = TestHandlers(self.hs) + self.hs.get_auth_handler().store = self.store self.auth = Auth(self.hs) # AuthBlocking reads from the hs' config on initialization. We need to @@ -67,7 +62,9 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_user_by_req_user_valid_token(self): - user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"} + user_info = TokenLookupResult( + user_id=self.test_user, token_id=5, device_id="device" + ) self.store.get_user_by_access_token = Mock( return_value=defer.succeed(user_info) ) @@ -90,7 +87,7 @@ class AuthTestCase(unittest.TestCase): self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") def test_get_user_by_req_user_missing_token(self): - user_info = {"name": self.test_user, "token_id": "ditto"} + user_info = TokenLookupResult(user_id=self.test_user, token_id=5) self.store.get_user_by_access_token = Mock( return_value=defer.succeed(user_info) ) @@ -227,7 +224,7 @@ class AuthTestCase(unittest.TestCase): def test_get_user_from_macaroon(self): self.store.get_user_by_access_token = Mock( return_value=defer.succeed( - {"name": "@baldrick:matrix.org", "device_id": "device"} + TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device") ) ) @@ -243,12 +240,11 @@ class AuthTestCase(unittest.TestCase): user_info = yield defer.ensureDeferred( self.auth.get_user_by_access_token(macaroon.serialize()) ) - user = user_info["user"] - self.assertEqual(UserID.from_string(user_id), user) + self.assertEqual(user_id, user_info.user_id) # TODO: device_id should come from the macaroon, but currently comes # from the db. 
- self.assertEqual(user_info["device_id"], "device") + self.assertEqual(user_info.device_id, "device") @defer.inlineCallbacks def test_get_guest_user_from_macaroon(self): @@ -270,10 +266,8 @@ class AuthTestCase(unittest.TestCase): user_info = yield defer.ensureDeferred( self.auth.get_user_by_access_token(serialized) ) - user = user_info["user"] - is_guest = user_info["is_guest"] - self.assertEqual(UserID.from_string(user_id), user) - self.assertTrue(is_guest) + self.assertEqual(user_id, user_info.user_id) + self.assertTrue(user_info.is_guest) self.store.get_user_by_id.assert_called_with(user_id) @defer.inlineCallbacks @@ -283,24 +277,25 @@ class AuthTestCase(unittest.TestCase): self.store.get_device = Mock(return_value=defer.succeed(None)) token = yield defer.ensureDeferred( - self.hs.handlers.auth_handler.get_access_token_for_user_id( + self.hs.get_auth_handler().get_access_token_for_user_id( USER_ID, "DEVICE", valid_until_ms=None ) ) self.store.add_access_token_to_user.assert_called_with( - USER_ID, token, "DEVICE", None + user_id=USER_ID, + token=token, + device_id="DEVICE", + valid_until_ms=None, + puppets_user_id=None, ) def get_user(tok): if token != tok: return defer.succeed(None) return defer.succeed( - { - "name": USER_ID, - "is_guest": False, - "token_id": 1234, - "device_id": "DEVICE", - } + TokenLookupResult( + user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE", + ) ) self.store.get_user_by_access_token = get_user diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
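A hedged sketch of the attribute-style lookups the tests above now expect from TokenLookupResult (the default for is_guest is an assumption, not stated in this diff):

    info = TokenLookupResult(user_id="@alice:test", token_id=5, device_id="device")
    info.user_id     # "@alice:test"
    info.device_id   # "device"
    info.is_guest    # assumed to default to False when not supplied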
index d2d535d23c..279c94a03d 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py
@@ -16,8 +16,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock - import jsonschema from twisted.internet import defer @@ -28,7 +26,7 @@ from synapse.api.filtering import Filter from synapse.events import make_event_from_dict from tests import unittest -from tests.utils import DeferredMockCallable, MockHttpResource, setup_test_homeserver +from tests.utils import setup_test_homeserver user_localpart = "test_user" @@ -42,22 +40,9 @@ def MockEvent(**kwargs): class FilteringTestCase(unittest.TestCase): - @defer.inlineCallbacks def setUp(self): - self.mock_federation_resource = MockHttpResource() - - self.mock_http_client = Mock(spec=[]) - self.mock_http_client.put_json = DeferredMockCallable() - - hs = yield setup_test_homeserver( - self.addCleanup, - handlers=None, - http_client=self.mock_http_client, - keyring=Mock(), - ) - + hs = setup_test_homeserver(self.addCleanup) self.filtering = hs.get_filtering() - self.datastore = hs.get_datastore() def test_errors_on_invalid_filters(self): diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index 1e1f30d790..fe504d0869 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py
@@ -43,7 +43,7 @@ class TestRatelimiter(unittest.TestCase): def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): appservice = ApplicationService( - None, "example.com", id="foo", rate_limited=True, + None, "example.com", id="foo", rate_limited=True, sender="@as:example.com", ) as_requester = create_requester("@user:example.com", app_service=appservice) @@ -68,7 +68,7 @@ class TestRatelimiter(unittest.TestCase): def test_allowed_appservice_via_can_requester_do_action(self): appservice = ApplicationService( - None, "example.com", id="foo", rate_limited=False, + None, "example.com", id="foo", rate_limited=False, sender="@as:example.com", ) as_requester = create_requester("@user:example.com", app_service=appservice) diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py
index 641093d349..e0ca288829 100644 --- a/tests/app/test_frontend_proxy.py +++ b/tests/app/test_frontend_proxy.py
@@ -15,6 +15,7 @@ from synapse.app.generic_worker import GenericWorkerServer +from tests.server import make_request from tests.unittest import HomeserverTestCase @@ -22,7 +23,7 @@ class FrontendProxyTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - http_client=None, homeserverToUse=GenericWorkerServer + federation_http_client=None, homeserver_to_use=GenericWorkerServer ) return hs @@ -55,10 +56,8 @@ class FrontendProxyTests(HomeserverTestCase): # Grab the resource from the site that was told to listen self.assertEqual(len(self.reactor.tcpServers), 1) site = self.reactor.tcpServers[0][1] - self.resource = site.resource.children[b"_matrix"].children[b"client"] - request, channel = self.make_request("PUT", "presence/a/status") - self.render(request) + channel = make_request(self.reactor, site, "PUT", "presence/a/status") # 400 + unrecognised, because nothing is registered self.assertEqual(channel.code, 400) @@ -77,10 +76,8 @@ class FrontendProxyTests(HomeserverTestCase): # Grab the resource from the site that was told to listen self.assertEqual(len(self.reactor.tcpServers), 1) site = self.reactor.tcpServers[0][1] - self.resource = site.resource.children[b"_matrix"].children[b"client"] - request, channel = self.make_request("PUT", "presence/a/status") - self.render(request) + channel = make_request(self.reactor, site, "PUT", "presence/a/status") # 401, because the stub servlet still checks authentication self.assertEqual(channel.code, 401) diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index 0f016c32eb..467033e201 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py
@@ -20,13 +20,14 @@ from synapse.app.generic_worker import GenericWorkerServer from synapse.app.homeserver import SynapseHomeServer from synapse.config.server import parse_listener_def +from tests.server import make_request from tests.unittest import HomeserverTestCase class FederationReaderOpenIDListenerTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - http_client=None, homeserverToUse=GenericWorkerServer + federation_http_client=None, homeserver_to_use=GenericWorkerServer ) return hs @@ -66,16 +67,15 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): # Grab the resource from the site that was told to listen site = self.reactor.tcpServers[0][1] try: - self.resource = site.resource.children[b"_matrix"].children[b"federation"] + site.resource.children[b"_matrix"].children[b"federation"] except KeyError: if expectation == "no_resource": return raise - request, channel = self.make_request( - "GET", "/_matrix/federation/v1/openid/userinfo" + channel = make_request( + self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo" ) - self.render(request) self.assertEqual(channel.code, 401) @@ -84,7 +84,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - http_client=None, homeserverToUse=SynapseHomeServer + federation_http_client=None, homeserver_to_use=SynapseHomeServer ) return hs @@ -115,15 +115,14 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): # Grab the resource from the site that was told to listen site = self.reactor.tcpServers[0][1] try: - self.resource = site.resource.children[b"_matrix"].children[b"federation"] + site.resource.children[b"_matrix"].children[b"federation"] except KeyError: if expectation == "no_resource": return raise - request, channel = self.make_request( - "GET", "/_matrix/federation/v1/openid/userinfo" + channel = make_request( + self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo" ) - self.render(request) self.assertEqual(channel.code, 401) diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index 236b608d58..0bffeb1150 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py
@@ -31,6 +31,7 @@ class ApplicationServiceTestCase(unittest.TestCase): def setUp(self): self.service = ApplicationService( id="unique_identifier", + sender="@as:test", url="some_url", token="some_token", hostname="matrix.org", # only used by get_groups_for_user diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index 68a4caabbf..97f8cad0dd 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py
@@ -60,7 +60,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( - service=service, events=events # txn made and saved + service=service, events=events, ephemeral=[] # txn made and saved ) self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made txn.complete.assert_called_once_with(self.store) # txn completed @@ -81,7 +81,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( - service=service, events=events # txn made and saved + service=service, events=events, ephemeral=[] # txn made and saved ) self.assertEquals(0, txn.send.call_count) # txn not sent though self.assertEquals(0, txn.complete.call_count) # or completed @@ -106,7 +106,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( - service=service, events=events + service=service, events=events, ephemeral=[] ) self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made self.assertEquals(1, self.recoverer.recover.call_count) # and invoked @@ -202,26 +202,28 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): # Expect the event to be sent immediately. service = Mock(id=4) event = Mock() - self.queuer.enqueue(service, event) - self.txn_ctrl.send.assert_called_once_with(service, [event]) + self.queuer.enqueue_event(service, event) + self.txn_ctrl.send.assert_called_once_with(service, [event], []) def test_send_single_event_with_queue(self): d = defer.Deferred() - self.txn_ctrl.send = Mock(side_effect=lambda x, y: make_deferred_yieldable(d)) + self.txn_ctrl.send = Mock( + side_effect=lambda x, y, z: make_deferred_yieldable(d) + ) service = Mock(id=4) event = Mock(event_id="first") event2 = Mock(event_id="second") event3 = Mock(event_id="third") # Send an event and don't resolve it just yet. - self.queuer.enqueue(service, event) + self.queuer.enqueue_event(service, event) # Send more events: expect send() to NOT be called multiple times. 
- self.queuer.enqueue(service, event2) - self.queuer.enqueue(service, event3) - self.txn_ctrl.send.assert_called_with(service, [event]) + self.queuer.enqueue_event(service, event2) + self.queuer.enqueue_event(service, event3) + self.txn_ctrl.send.assert_called_with(service, [event], []) self.assertEquals(1, self.txn_ctrl.send.call_count) # Resolve the send event: expect the queued events to be sent d.callback(service) - self.txn_ctrl.send.assert_called_with(service, [event2, event3]) + self.txn_ctrl.send.assert_called_with(service, [event2, event3], []) self.assertEquals(2, self.txn_ctrl.send.call_count) def test_multiple_service_queues(self): @@ -239,21 +241,99 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): send_return_list = [srv_1_defer, srv_2_defer] - def do_send(x, y): + def do_send(x, y, z): return make_deferred_yieldable(send_return_list.pop(0)) self.txn_ctrl.send = Mock(side_effect=do_send) # send events for different ASes and make sure they are sent - self.queuer.enqueue(srv1, srv_1_event) - self.queuer.enqueue(srv1, srv_1_event2) - self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event]) - self.queuer.enqueue(srv2, srv_2_event) - self.queuer.enqueue(srv2, srv_2_event2) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event]) + self.queuer.enqueue_event(srv1, srv_1_event) + self.queuer.enqueue_event(srv1, srv_1_event2) + self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], []) + self.queuer.enqueue_event(srv2, srv_2_event) + self.queuer.enqueue_event(srv2, srv_2_event2) + self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], []) # make sure callbacks for a service only send queued events for THAT # service srv_2_defer.callback(srv2) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2]) + self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], []) + self.assertEquals(3, self.txn_ctrl.send.call_count) + + def test_send_large_txns(self): + srv_1_defer = defer.Deferred() + srv_2_defer = defer.Deferred() + send_return_list = [srv_1_defer, srv_2_defer] + + def do_send(x, y, z): + return make_deferred_yieldable(send_return_list.pop(0)) + + self.txn_ctrl.send = Mock(side_effect=do_send) + + service = Mock(id=4, name="service") + event_list = [Mock(name="event%i" % (i + 1)) for i in range(200)] + for event in event_list: + self.queuer.enqueue_event(service, event) + + # Expect the first event to be sent immediately. + self.txn_ctrl.send.assert_called_with(service, [event_list[0]], []) + srv_1_defer.callback(service) + # Then send the next 100 events + self.txn_ctrl.send.assert_called_with(service, event_list[1:101], []) + srv_2_defer.callback(service) + # Then the final 99 events + self.txn_ctrl.send.assert_called_with(service, event_list[101:], []) self.assertEquals(3, self.txn_ctrl.send.call_count) + + def test_send_single_ephemeral_no_queue(self): + # Expect the event to be sent immediately. + service = Mock(id=4, name="service") + event_list = [Mock(name="event")] + self.queuer.enqueue_ephemeral(service, event_list) + self.txn_ctrl.send.assert_called_once_with(service, [], event_list) + + def test_send_multiple_ephemeral_no_queue(self): + # Expect the event to be sent immediately. 
+ service = Mock(id=4, name="service") + event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")] + self.queuer.enqueue_ephemeral(service, event_list) + self.txn_ctrl.send.assert_called_once_with(service, [], event_list) + + def test_send_single_ephemeral_with_queue(self): + d = defer.Deferred() + self.txn_ctrl.send = Mock( + side_effect=lambda x, y, z: make_deferred_yieldable(d) + ) + service = Mock(id=4) + event_list_1 = [Mock(event_id="event1"), Mock(event_id="event2")] + event_list_2 = [Mock(event_id="event3"), Mock(event_id="event4")] + event_list_3 = [Mock(event_id="event5"), Mock(event_id="event6")] + + # Send an event and don't resolve it just yet. + self.queuer.enqueue_ephemeral(service, event_list_1) + # Send more events: expect send() to NOT be called multiple times. + self.queuer.enqueue_ephemeral(service, event_list_2) + self.queuer.enqueue_ephemeral(service, event_list_3) + self.txn_ctrl.send.assert_called_with(service, [], event_list_1) + self.assertEquals(1, self.txn_ctrl.send.call_count) + # Resolve txn_ctrl.send + d.callback(service) + # Expect the queued events to be sent + self.txn_ctrl.send.assert_called_with(service, [], event_list_2 + event_list_3) + self.assertEquals(2, self.txn_ctrl.send.call_count) + + def test_send_large_txns_ephemeral(self): + d = defer.Deferred() + self.txn_ctrl.send = Mock( + side_effect=lambda x, y, z: make_deferred_yieldable(d) + ) + # Expect the event to be sent immediately. + service = Mock(id=4, name="service") + first_chunk = [Mock(name="event%i" % (i + 1)) for i in range(100)] + second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)] + event_list = first_chunk + second_chunk + self.queuer.enqueue_ephemeral(service, event_list) + self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk) + d.callback(service) + self.txn_ctrl.send.assert_called_with(service, [], second_chunk) + self.assertEquals(2, self.txn_ctrl.send.call_count) diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
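The large-transaction tests above imply that queued events are flushed at most 100 per transaction; a hedged sketch of that chunking (the constant name is hypothetical and the limit is inferred from the tests, not from scheduler code shown here):

    TXN_LIMIT = 100  # hypothetical name for the per-transaction cap
    queued = list(range(150))
    chunks = [queued[i : i + TXN_LIMIT] for i in range(0, len(queued), TXN_LIMIT)]
    # -> two sends: one with 100 items, then one with the remaining 50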
index 8ff1460c0d..1d65ea2f9c 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py
@@ -75,7 +75,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): return val def test_verify_json_objects_for_server_awaits_previous_requests(self): - mock_fetcher = keyring.KeyFetcher() + mock_fetcher = Mock() mock_fetcher.get_keys = Mock() kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,)) @@ -195,7 +195,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): """Tests that we correctly handle key requests for keys we've stored with a null `ts_valid_until_ms` """ - mock_fetcher = keyring.KeyFetcher() + mock_fetcher = Mock() mock_fetcher.get_keys = Mock(return_value=make_awaitable({})) kr = keyring.Keyring( @@ -249,7 +249,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): } } - mock_fetcher = keyring.KeyFetcher() + mock_fetcher = Mock() mock_fetcher.get_keys = Mock(side_effect=get_keys) kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,)) @@ -288,9 +288,9 @@ class KeyringTestCase(unittest.HomeserverTestCase): } } - mock_fetcher1 = keyring.KeyFetcher() + mock_fetcher1 = Mock() mock_fetcher1.get_keys = Mock(side_effect=get_keys1) - mock_fetcher2 = keyring.KeyFetcher() + mock_fetcher2 = Mock() mock_fetcher2.get_keys = Mock(side_effect=get_keys2) kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2)) @@ -315,7 +315,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.http_client = Mock() - hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client) + hs = self.setup_test_homeserver(federation_http_client=self.http_client) return hs def test_get_keys_from_server(self): @@ -396,7 +396,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): ] return self.setup_test_homeserver( - handlers=None, http_client=self.http_client, config=config + federation_http_client=self.http_client, config=config ) def build_perspectives_response( diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 1471cc1a28..9ccd2d76b8 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py
@@ -48,10 +48,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): ) # Get the room complexity - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) ) - self.render(request) self.assertEquals(200, channel.code) complexity = channel.json_body["v1"] self.assertTrue(complexity > 0, complexity) @@ -61,10 +60,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23) # Get the room complexity again -- make sure it's our artificial value - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) ) - self.render(request) self.assertEquals(200, channel.code) complexity = channel.json_body["v1"] self.assertEqual(complexity, 1.23) diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index da933ecd75..cfeccc0577 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py
@@ -46,12 +46,11 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase): "/get_missing_events/(?P<room_id>[^/]*)/?" - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/federation/v1/get_missing_events/%s" % (room_1,), query_content, ) - self.render(request) self.assertEquals(400, channel.code, channel.result) self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON") @@ -96,10 +95,9 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): room_1 = self.helper.create_room_as(u1, tok=u1_token) self.inject_room_member(room_1, "@user:other.example.com", "join") - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/federation/v1/state/%s" % (room_1,) ) - self.render(request) self.assertEquals(200, channel.code, channel.result) self.assertEqual( @@ -129,10 +127,9 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): room_1 = self.helper.create_room_as(u1, tok=u1_token) - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/federation/v1/state/%s" % (room_1,) ) - self.render(request) self.assertEquals(403, channel.code, channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") diff --git a/tests/federation/transport/__init__.py b/tests/federation/transport/__init__.py new file mode 100644
index 0000000000..e69de29bb2 --- /dev/null +++ b/tests/federation/transport/__init__.py
diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py
index 72e22d655f..85500e169c 100644 --- a/tests/federation/transport/test_server.py +++ b/tests/federation/transport/test_server.py
@@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,40 +13,31 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from twisted.internet import defer - -from synapse.config.ratelimiting import FederationRateLimitConfig -from synapse.federation.transport import server -from synapse.util.ratelimitutils import FederationRateLimiter - from tests import unittest from tests.unittest import override_config -class RoomDirectoryFederationTests(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, homeserver): - class Authenticator: - def authenticate_request(self, request, content): - return defer.succeed("otherserver.nottld") - - ratelimiter = FederationRateLimiter(clock, FederationRateLimitConfig()) - server.register_servlets( - homeserver, self.resource, Authenticator(), ratelimiter - ) - +class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase): @override_config({"allow_public_rooms_over_federation": False}) def test_blocked_public_room_list_over_federation(self): - request, channel = self.make_request( - "GET", "/_matrix/federation/v1/publicRooms" + """Test that unauthenticated requests to the public rooms directory 403 when + allow_public_rooms_over_federation is False. + """ + channel = self.make_request( + "GET", + "/_matrix/federation/v1/publicRooms", + federation_auth_origin=b"example.com", ) - self.render(request) self.assertEquals(403, channel.code) @override_config({"allow_public_rooms_over_federation": True}) def test_open_public_room_list_over_federation(self): - request, channel = self.make_request( - "GET", "/_matrix/federation/v1/publicRooms" + """Test that unauthenticated requests to the public rooms directory 200 when + allow_public_rooms_over_federation is True. + """ + channel = self.make_request( + "GET", + "/_matrix/federation/v1/publicRooms", + federation_auth_origin=b"example.com", ) - self.render(request) self.assertEquals(200, channel.code) diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index fc37c4328c..5c2b4de1a6 100644 --- a/tests/handlers/test_admin.py +++ b/tests/handlers/test_admin.py
@@ -35,7 +35,7 @@ class ExfiltrateData(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.admin_handler = hs.get_handlers().admin_handler + self.admin_handler = hs.get_admin_handler() self.user1 = self.register_user("user1", "password") self.token1 = self.login("user1", "password") diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 2a0b7c1b56..53763cd0f9 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py
@@ -18,6 +18,7 @@ from mock import Mock from twisted.internet import defer from synapse.handlers.appservice import ApplicationServicesHandler +from synapse.types import RoomStreamToken from tests.test_utils import make_awaitable from tests.utils import MockClock @@ -41,7 +42,6 @@ class AppServiceHandlerTestCase(unittest.TestCase): hs.get_clock.return_value = MockClock() self.handler = ApplicationServicesHandler(hs) - @defer.inlineCallbacks def test_notify_interested_services(self): interested_service = self._mkservice(is_interested=True) services = [ @@ -61,12 +61,12 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.succeed((0, [event])), defer.succeed((0, [])), ] - yield defer.ensureDeferred(self.handler.notify_interested_services(0)) + self.handler.notify_interested_services(RoomStreamToken(None, 0)) + self.mock_scheduler.submit_event_for_as.assert_called_once_with( interested_service, event ) - @defer.inlineCallbacks def test_query_user_exists_unknown_user(self): user_id = "@someone:anywhere" services = [self._mkservice(is_interested=True)] @@ -80,10 +80,11 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.succeed((0, [event])), defer.succeed((0, [])), ] - yield defer.ensureDeferred(self.handler.notify_interested_services(0)) + + self.handler.notify_interested_services(RoomStreamToken(None, 0)) + self.mock_as_api.query_user.assert_called_once_with(services[0], user_id) - @defer.inlineCallbacks def test_query_user_exists_known_user(self): user_id = "@someone:anywhere" services = [self._mkservice(is_interested=True)] @@ -97,7 +98,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.succeed((0, [event])), defer.succeed((0, [])), ] - yield defer.ensureDeferred(self.handler.notify_interested_services(0)) + + self.handler.notify_interested_services(RoomStreamToken(None, 0)) + self.assertFalse( self.mock_as_api.query_user.called, "query_user called when it shouldn't have been.", diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 97877c2e42..e24ce81284 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py
@@ -21,24 +21,17 @@ from twisted.internet import defer import synapse import synapse.api.errors from synapse.api.errors import ResourceLimitError -from synapse.handlers.auth import AuthHandler from tests import unittest from tests.test_utils import make_awaitable from tests.utils import setup_test_homeserver -class AuthHandlers: - def __init__(self, hs): - self.auth_handler = AuthHandler(hs) - - class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None) - self.hs.handlers = AuthHandlers(self.hs) - self.auth_handler = self.hs.handlers.auth_handler + self.hs = yield setup_test_homeserver(self.addCleanup) + self.auth_handler = self.hs.get_auth_handler() self.macaroon_generator = self.hs.get_macaroon_generator() # MAU tests @@ -59,7 +52,7 @@ class AuthTestCase(unittest.TestCase): self.fail("some_user was not in %s" % macaroon.inspect()) def test_macaroon_caveats(self): - self.hs.clock.now = 5000 + self.hs.get_clock().now = 5000 token = self.macaroon_generator.generate_access_token("a_user") macaroon = pymacaroons.Macaroon.deserialize(token) @@ -85,7 +78,7 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_short_term_login_token_gives_user_id(self): - self.hs.clock.now = 1000 + self.hs.get_clock().now = 1000 token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) user_id = yield defer.ensureDeferred( @@ -94,7 +87,7 @@ class AuthTestCase(unittest.TestCase): self.assertEqual("a_user", user_id) # when we advance the clock, the token should be rejected - self.hs.clock.now = 6000 + self.hs.get_clock().now = 6000 with self.assertRaises(synapse.api.errors.AuthError): yield defer.ensureDeferred( self.auth_handler.validate_short_term_login_token_and_get_user_id(token) diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py new file mode 100644
index 0000000000..bd7a1b6891 --- /dev/null +++ b/tests/handlers/test_cas.py
@@ -0,0 +1,121 @@ +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from mock import Mock + +from synapse.handlers.cas_handler import CasResponse + +from tests.test_utils import simple_async_mock +from tests.unittest import HomeserverTestCase + +# These are a few constants that are used as config parameters in the tests. +BASE_URL = "https://synapse/" +SERVER_URL = "https://issuer/" + + +class CasHandlerTestCase(HomeserverTestCase): + def default_config(self): + config = super().default_config() + config["public_baseurl"] = BASE_URL + cas_config = { + "enabled": True, + "server_url": SERVER_URL, + "service_url": BASE_URL, + } + config["cas_config"] = cas_config + + return config + + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + + self.handler = hs.get_cas_handler() + + # Reduce the number of attempts when generating MXIDs. + sso_handler = hs.get_sso_handler() + sso_handler._MAP_USERNAME_RETRIES = 3 + + return hs + + def test_map_cas_user_to_user(self): + """Ensure that mapping the CAS user returned from a provider to an MXID works properly.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + cas_response = CasResponse("test_user", {}) + request = _mock_request() + self.get_success( + self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "redirect_uri", None + ) + + def test_map_cas_user_to_existing_user(self): + """Existing users can log in with CAS account.""" + store = self.hs.get_datastore() + self.get_success( + store.register_user(user_id="@test_user:test", password_hash=None) + ) + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # Map a user via SSO. + cas_response = CasResponse("test_user", {}) + request = _mock_request() + self.get_success( + self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "redirect_uri", None + ) + + # Subsequent calls should map to the same mxid. 
+ auth_handler.complete_sso_login.reset_mock() + self.get_success( + self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") + ) + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "redirect_uri", None + ) + + def test_map_cas_user_to_invalid_localpart(self): + """CAS automaps invalid characters to base-64 encoding.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + cas_response = CasResponse("föö", {}) + request = _mock_request() + self.get_success( + self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@f=c3=b6=c3=b6:test", request, "redirect_uri", None + ) + + +def _mock_request(): + """Returns a mock which will stand in as a SynapseRequest""" + return Mock(spec=["getClientIP", "get_user_agent"]) diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
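The invalid-localpart test above expects "föö" to map to "@f=c3=b6=c3=b6:test", which suggests each disallowed character is replaced by '=' plus the lowercase hex of its UTF-8 bytes; a hedged re-implementation inferred only from that expected value:

    def encode_char(c: str) -> str:
        # hypothetical helper, not the actual mapping code
        return "".join("=%02x" % b for b in c.encode("utf-8"))

    "f" + encode_char("ö") + encode_char("ö")   # -> "f=c3=b6=c3=b6"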
index 969d44c787..5dfeccfeb6 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py
@@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd # Copyright 2018 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,7 +27,7 @@ user2 = "@theresa:bbb" class DeviceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver("server", http_client=None) + hs = self.setup_test_homeserver("server", federation_http_client=None) self.handler = hs.get_device_handler() self.store = hs.get_datastore() return hs @@ -224,3 +225,84 @@ class DeviceTestCase(unittest.HomeserverTestCase): ) ) self.reactor.advance(1000) + + +class DehydrationTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver("server", federation_http_client=None) + self.handler = hs.get_device_handler() + self.registration = hs.get_registration_handler() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + return hs + + def test_dehydrate_and_rehydrate_device(self): + user_id = "@boris:dehydration" + + self.get_success(self.store.register_user(user_id, "foobar")) + + # First check if we can store and fetch a dehydrated device + stored_dehydrated_device_id = self.get_success( + self.handler.store_dehydrated_device( + user_id=user_id, + device_data={"device_data": {"foo": "bar"}}, + initial_device_display_name="dehydrated device", + ) + ) + + retrieved_device_id, device_data = self.get_success( + self.handler.get_dehydrated_device(user_id=user_id) + ) + + self.assertEqual(retrieved_device_id, stored_dehydrated_device_id) + self.assertEqual(device_data, {"device_data": {"foo": "bar"}}) + + # Create a new login for the user and dehydrated the device + device_id, access_token = self.get_success( + self.registration.register_device( + user_id=user_id, device_id=None, initial_display_name="new device", + ) + ) + + # Trying to claim a nonexistent device should throw an error + self.get_failure( + self.handler.rehydrate_device( + user_id=user_id, + access_token=access_token, + device_id="not the right device ID", + ), + synapse.api.errors.NotFoundError, + ) + + # dehydrating the right devices should succeed and change our device ID + # to the dehydrated device's ID + res = self.get_success( + self.handler.rehydrate_device( + user_id=user_id, + access_token=access_token, + device_id=retrieved_device_id, + ) + ) + + self.assertEqual(res, {"success": True}) + + # make sure that our device ID has changed + user_info = self.get_success(self.auth.get_user_by_access_token(access_token)) + + self.assertEqual(user_info.device_id, retrieved_device_id) + + # make sure the device has the display name that was set from the login + res = self.get_success(self.handler.get_device(user_id, retrieved_device_id)) + + self.assertEqual(res["display_name"], "new device") + + # make sure that the device ID that we were initially assigned no longer exists + self.get_failure( + self.handler.get_device(user_id, device_id), + synapse.api.errors.NotFoundError, + ) + + # make sure that there's no device available for dehydrating now + ret = self.get_success(self.handler.get_dehydrated_device(user_id=user_id)) + + self.assertIsNone(ret) diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index bc0c5aefdc..a39f898608 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py
@@ -42,13 +42,11 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.mock_registry.register_query_handler = register_query_handler hs = self.setup_test_homeserver( - http_client=None, - resource_for_federation=Mock(), federation_client=self.mock_federation, federation_registry=self.mock_registry, ) - self.handler = hs.get_handlers().directory_handler + self.handler = hs.get_directory_handler() self.store = hs.get_datastore() @@ -110,7 +108,7 @@ class TestCreateAlias(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.handler = hs.get_handlers().directory_handler + self.handler = hs.get_directory_handler() # Create user self.admin_user = self.register_user("admin", "pass", admin=True) @@ -173,7 +171,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - self.handler = hs.get_handlers().directory_handler + self.handler = hs.get_directory_handler() self.state_handler = hs.get_state_handler() # Create user @@ -289,7 +287,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - self.handler = hs.get_handlers().directory_handler + self.handler = hs.get_directory_handler() self.state_handler = hs.get_state_handler() # Create user @@ -407,23 +405,21 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): def test_denied(self): room_id = self.helper.create_room_as(self.user_id) - request, channel = self.make_request( + channel = self.make_request( "PUT", b"directory/room/%23test%3Atest", ('{"room_id":"%s"}' % (room_id,)).encode("ascii"), ) - self.render(request) self.assertEquals(403, channel.code, channel.result) def test_allowed(self): room_id = self.helper.create_room_as(self.user_id) - request, channel = self.make_request( + channel = self.make_request( "PUT", b"directory/room/%23unofficial_test%3Atest", ('{"room_id":"%s"}' % (room_id,)).encode("ascii"), ) - self.render(request) self.assertEquals(200, channel.code, channel.result) @@ -435,14 +431,13 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): room_id = self.helper.create_room_as(self.user_id) - request, channel = self.make_request( + channel = self.make_request( "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}" ) - self.render(request) self.assertEquals(200, channel.code, channel.result) self.room_list_handler = hs.get_room_list_handler() - self.directory_handler = hs.get_handlers().directory_handler + self.directory_handler = hs.get_directory_handler() return hs @@ -451,8 +446,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): self.directory_handler.enable_room_list_search = True # Room list is enabled so we should get some results - request, channel = self.make_request("GET", b"publicRooms") - self.render(request) + channel = self.make_request("GET", b"publicRooms") self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["chunk"]) > 0) @@ -460,15 +454,13 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): self.directory_handler.enable_room_list_search = False # Room list disabled so we should get no results - request, channel = self.make_request("GET", b"publicRooms") - self.render(request) + channel = self.make_request("GET", b"publicRooms") self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["chunk"]) == 0) # Room list disabled so we shouldn't be allowed to publish 
rooms room_id = self.helper.create_room_as(self.user_id) - request, channel = self.make_request( + channel = self.make_request( "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}" ) - self.render(request) self.assertEquals(403, channel.code, channel.result) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
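The repeated edits in this file track a test-harness change: make_request now renders the request itself and returns the finished channel, so the separate self.render(request) step goes away. A short before/after fragment, taken from inside a HomeserverTestCase test method (so it is not standalone code):

# Before: make_request only built the request; rendering was a separate step.
#     request, channel = self.make_request("GET", b"publicRooms")
#     self.render(request)
#
# After: one call does both and hands back the completed channel.
channel = self.make_request("GET", b"publicRooms")
self.assertEquals(200, channel.code, channel.result)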
index 366dcfb670..924f29f051 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py
@@ -33,13 +33,15 @@ class E2eKeysHandlerTestCase(unittest.TestCase): super().__init__(*args, **kwargs) self.hs = None # type: synapse.server.HomeServer self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler + self.store = None # type: synapse.storage.Storage @defer.inlineCallbacks def setUp(self): self.hs = yield utils.setup_test_homeserver( - self.addCleanup, handlers=None, federation_client=mock.Mock() + self.addCleanup, federation_client=mock.Mock() ) self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs) + self.store = self.hs.get_datastore() @defer.inlineCallbacks def test_query_local_devices_no_devices(self): @@ -172,6 +174,89 @@ class E2eKeysHandlerTestCase(unittest.TestCase): ) @defer.inlineCallbacks + def test_fallback_key(self): + local_user = "@boris:" + self.hs.hostname + device_id = "xyz" + fallback_key = {"alg1:k1": "key1"} + otk = {"alg1:k2": "key2"} + + # we shouldn't have any unused fallback keys yet + res = yield defer.ensureDeferred( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id) + ) + self.assertEqual(res, []) + + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, + device_id, + {"org.matrix.msc2732.fallback_keys": fallback_key}, + ) + ) + + # we should now have an unused alg1 key + res = yield defer.ensureDeferred( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id) + ) + self.assertEqual(res, ["alg1"]) + + # claiming an OTK when no OTKs are available should return the fallback + # key + res = yield defer.ensureDeferred( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + ) + ) + self.assertEqual( + res, + {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, + ) + + # we shouldn't have any unused fallback keys again + res = yield defer.ensureDeferred( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id) + ) + self.assertEqual(res, []) + + # claiming an OTK again should return the same fallback key + res = yield defer.ensureDeferred( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + ) + ) + self.assertEqual( + res, + {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, + ) + + # if the user uploads a one-time key, the next claim should fetch the + # one-time key, and then go back to the fallback + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": otk} + ) + ) + + res = yield defer.ensureDeferred( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + ) + ) + self.assertEqual( + res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}}, + ) + + res = yield defer.ensureDeferred( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + ) + ) + self.assertEqual( + res, + {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, + ) + + @defer.inlineCallbacks def test_replace_master_key(self): """uploading a new signing key should make the old signing key unavailable""" local_user = "@boris:" + self.hs.hostname diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
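test_fallback_key above pins down the claim ordering for MSC2732 fallback keys: a claim consumes one-time keys first, falls back to the reusable fallback key only when none are left, and an uploaded fallback key counts as "unused" until its first claim. The following is a small in-memory model of that ordering, purely illustrative; the real logic lives in the E2E keys handler and store.

from typing import Dict, List, Optional


class KeyStore:
    """Toy per-device key store mirroring the ordering asserted above."""

    def __init__(self) -> None:
        self.one_time_keys: List[Dict[str, str]] = []
        self.fallback_key: Optional[Dict[str, str]] = None
        self.fallback_used = False

    def upload(self, one_time_keys=None, fallback_key=None) -> None:
        if one_time_keys:
            self.one_time_keys.append(one_time_keys)
        if fallback_key is not None:
            self.fallback_key = fallback_key
            self.fallback_used = False  # a fresh fallback key is "unused"

    def unused_fallback_key_types(self) -> List[str]:
        if self.fallback_key is not None and not self.fallback_used:
            return [key_id.split(":", 1)[0] for key_id in self.fallback_key]
        return []

    def claim(self) -> Optional[Dict[str, str]]:
        if self.one_time_keys:
            return self.one_time_keys.pop(0)  # one-time keys are consumed
        if self.fallback_key is not None:
            self.fallback_used = True  # fallback is returned but kept
            return self.fallback_key
        return None


ks = KeyStore()
ks.upload(fallback_key={"alg1:k1": "key1"})
assert ks.unused_fallback_key_types() == ["alg1"]
assert ks.claim() == {"alg1:k1": "key1"}       # no OTKs, so the fallback is served
assert ks.unused_fallback_key_types() == []    # now marked as used
assert ks.claim() == {"alg1:k1": "key1"}       # the fallback can be served again
ks.upload(one_time_keys={"alg1:k2": "key2"})
assert ks.claim() == {"alg1:k2": "key2"}       # an uploaded OTK takes priority
assert ks.claim() == {"alg1:k1": "key1"}       # then back to the fallback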
index 7adde9b9de..45f201a399 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py
@@ -54,7 +54,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.hs = yield utils.setup_test_homeserver( - self.addCleanup, handlers=None, replication_layer=mock.Mock() + self.addCleanup, replication_layer=mock.Mock() ) self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs) self.local_user = "@boris:" + self.hs.hostname diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 96fea58673..0b24b89a2e 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py
@@ -37,8 +37,8 @@ class FederationTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver(http_client=None) - self.handler = hs.get_handlers().federation_handler + hs = self.setup_test_homeserver(federation_http_client=None) + self.handler = hs.get_federation_handler() self.store = hs.get_datastore() return hs @@ -59,7 +59,6 @@ class FederationTestCase(unittest.HomeserverTestCase): ) d = self.handler.on_exchange_third_party_invite_request( - room_id=room_id, event_dict={ "type": EventTypes.Member, "room_id": room_id, @@ -127,7 +126,7 @@ class FederationTestCase(unittest.HomeserverTestCase): room_version, ) - with LoggingContext(request="send_rejected"): + with LoggingContext("send_rejected"): d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev) self.get_success(d) @@ -179,7 +178,7 @@ class FederationTestCase(unittest.HomeserverTestCase): room_version, ) - with LoggingContext(request="send_rejected"): + with LoggingContext("send_rejected"): d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev) self.get_success(d) @@ -199,7 +198,7 @@ class FederationTestCase(unittest.HomeserverTestCase): # the auth code requires that a signature exists, but doesn't check that # signature... go figure. join_event.signatures[other_server] = {"x": "y"} - with LoggingContext(request="send_join"): + with LoggingContext("send_join"): d = run_in_background( self.handler.on_send_join_request, other_server, join_event ) diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py new file mode 100644
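The LoggingContext edits in this file are mechanical: the context name is now passed as the first positional argument rather than as a request= keyword. A minimal hedged sketch of the pattern these tests use; the coroutine here is only a stand-in for on_receive_pdu.

from synapse.logging.context import LoggingContext, run_in_background


async def _stand_in_handler(origin: str, pdu: object) -> None:
    # Placeholder for FederationHandler.on_receive_pdu in the tests above.
    pass


pdu = object()  # stand-in for a rejected event

# Log lines emitted while the context is active are tagged "send_rejected".
with LoggingContext("send_rejected"):
    d = run_in_background(_stand_in_handler, "other.example.com", pdu)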
index 0000000000..f955dfa490 --- /dev/null +++ b/tests/handlers/test_message.py
@@ -0,0 +1,212 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import Tuple + +from synapse.api.constants import EventTypes +from synapse.events import EventBase +from synapse.events.snapshot import EventContext +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.types import create_requester +from synapse.util.stringutils import random_string + +from tests import unittest + +logger = logging.getLogger(__name__) + + +class EventCreationTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.handler = self.hs.get_event_creation_handler() + self.persist_event_storage = self.hs.get_storage().persistence + + self.user_id = self.register_user("tester", "foobar") + self.access_token = self.login("tester", "foobar") + self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token) + + self.info = self.get_success( + self.hs.get_datastore().get_user_by_access_token(self.access_token,) + ) + self.token_id = self.info.token_id + + self.requester = create_requester(self.user_id, access_token_id=self.token_id) + + def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]: + """Create a new event with the given transaction ID. All events produced + by this method will be considered duplicates. + """ + + # We create a new event with a random body, as otherwise we'll produce + # *exactly* the same event with the same hash, and so same event ID. + return self.get_success( + self.handler.create_event( + self.requester, + { + "type": EventTypes.Message, + "room_id": self.room_id, + "sender": self.requester.user.to_string(), + "content": {"msgtype": "m.text", "body": random_string(5)}, + }, + txn_id=txn_id, + ) + ) + + def test_duplicated_txn_id(self): + """Test that attempting to handle/persist an event with a transaction ID + that has already been persisted correctly returns the old event and does + *not* produce duplicate messages. + """ + + txn_id = "something_suitably_random" + + event1, context = self._create_duplicate_event(txn_id) + + ret_event1 = self.get_success( + self.handler.handle_new_client_event(self.requester, event1, context) + ) + stream_id1 = ret_event1.internal_metadata.stream_ordering + + self.assertEqual(event1.event_id, ret_event1.event_id) + + event2, context = self._create_duplicate_event(txn_id) + + # We want to test that the deduplication at the persit event end works, + # so we want to make sure we test with different events. + self.assertNotEqual(event1.event_id, event2.event_id) + + ret_event2 = self.get_success( + self.handler.handle_new_client_event(self.requester, event2, context) + ) + stream_id2 = ret_event2.internal_metadata.stream_ordering + + # Assert that the returned values match those from the initial event + # rather than the new one. 
+ self.assertEqual(ret_event1.event_id, ret_event2.event_id) + self.assertEqual(stream_id1, stream_id2) + + # Let's test that calling `persist_event` directly also does the right + # thing. + event3, context = self._create_duplicate_event(txn_id) + self.assertNotEqual(event1.event_id, event3.event_id) + + ret_event3, event_pos3, _ = self.get_success( + self.persist_event_storage.persist_event(event3, context) + ) + + # Assert that the returned values match those from the initial event + # rather than the new one. + self.assertEqual(ret_event1.event_id, ret_event3.event_id) + self.assertEqual(stream_id1, event_pos3.stream) + + # Let's test that calling `persist_events` directly also does the right + # thing. + event4, context = self._create_duplicate_event(txn_id) + self.assertNotEqual(event1.event_id, event3.event_id) + + events, _ = self.get_success( + self.persist_event_storage.persist_events([(event3, context)]) + ) + ret_event4 = events[0] + + # Assert that the returned values match those from the initial event + # rather than the new one. + self.assertEqual(ret_event1.event_id, ret_event4.event_id) + + def test_duplicated_txn_id_one_call(self): + """Test that we correctly handle duplicates that we try and persist at + the same time. + """ + + txn_id = "something_else_suitably_random" + + # Create two duplicate events to persist at the same time + event1, context1 = self._create_duplicate_event(txn_id) + event2, context2 = self._create_duplicate_event(txn_id) + + # Ensure their event IDs are different to start with + self.assertNotEqual(event1.event_id, event2.event_id) + + events, _ = self.get_success( + self.persist_event_storage.persist_events( + [(event1, context1), (event2, context2)] + ) + ) + + # Check that we've deduplicated the events. + self.assertEqual(len(events), 2) + self.assertEqual(events[0].event_id, events[1].event_id) + + +class ServerAclValidationTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.user_id = self.register_user("tester", "foobar") + self.access_token = self.login("tester", "foobar") + self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token) + + def test_allow_server_acl(self): + """Test that sending an ACL that blocks everyone but ourselves works. + """ + + self.helper.send_state( + self.room_id, + EventTypes.ServerACL, + body={"allow": [self.hs.hostname]}, + tok=self.access_token, + expect_code=200, + ) + + def test_deny_server_acl_block_outselves(self): + """Test that sending an ACL that blocks ourselves does not work. + """ + self.helper.send_state( + self.room_id, + EventTypes.ServerACL, + body={}, + tok=self.access_token, + expect_code=400, + ) + + def test_deny_redact_server_acl(self): + """Test that attempting to redact an ACL is blocked. + """ + + body = self.helper.send_state( + self.room_id, + EventTypes.ServerACL, + body={"allow": [self.hs.hostname]}, + tok=self.access_token, + expect_code=200, + ) + event_id = body["event_id"] + + # Redaction of event should fail. + path = "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room_id, event_id) + channel = self.make_request( + "POST", path, content={}, access_token=self.access_token + ) + self.assertEqual(int(channel.result["code"]), 403) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
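The two test cases in the new test_message.py check event deduplication by client transaction ID at both the handler and the persistence layer: a retry with the same txn_id must return the original event and stream position rather than creating a duplicate. The essence is a lookup keyed on the requester's access token and the transaction ID; the sketch below uses hypothetical names and an in-memory dict in place of the database table.

from typing import Dict, Optional, Tuple


class TxnDeduplicator:
    """Toy mapping of (token_id, room_id, txn_id) -> event_id."""

    def __init__(self) -> None:
        self._seen: Dict[Tuple[int, str, str], str] = {}

    def lookup(self, token_id: int, room_id: str, txn_id: str) -> Optional[str]:
        return self._seen.get((token_id, room_id, txn_id))

    def record(self, token_id: int, room_id: str, txn_id: str, event_id: str) -> None:
        self._seen.setdefault((token_id, room_id, txn_id), event_id)


dedup = TxnDeduplicator()
token_id, room_id, txn_id = 1, "!room:test", "something_suitably_random"

assert dedup.lookup(token_id, room_id, txn_id) is None  # first send: nothing to dedupe
dedup.record(token_id, room_id, txn_id, "$event1")

# A retry with the same txn_id resolves to the original event ID, so the
# caller hands back the already-persisted event instead of a duplicate.
assert dedup.lookup(token_id, room_id, txn_id) == "$event1"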
index d5087e58be..368d600b33 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py
@@ -12,40 +12,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import json -from urllib.parse import parse_qs, urlparse +import re +from typing import Dict +from urllib.parse import parse_qs, urlencode, urlparse -from mock import Mock, patch +from mock import ANY, Mock, patch -import attr import pymacaroons -from twisted.python.failure import Failure -from twisted.web._newclient import ResponseDone +from twisted.web.resource import Resource -from synapse.handlers.oidc_handler import ( - MappingException, - OidcError, - OidcHandler, - OidcMappingProvider, -) +from synapse.api.errors import RedirectException +from synapse.handlers.oidc_handler import OidcError +from synapse.handlers.sso import MappingException +from synapse.rest.client.v1 import login +from synapse.rest.synapse.client.pick_username import pick_username_resource +from synapse.server import HomeServer from synapse.types import UserID +from tests.test_utils import FakeResponse, simple_async_mock from tests.unittest import HomeserverTestCase, override_config - -@attr.s -class FakeResponse: - code = attr.ib() - body = attr.ib() - phrase = attr.ib() - - def deliverBody(self, protocol): - protocol.dataReceived(self.body) - protocol.connectionLost(Failure(ResponseDone())) - - # These are a few constants that are used as config parameters in the tests. ISSUER = "https://issuer/" CLIENT_ID = "test-client-id" @@ -75,11 +63,14 @@ COOKIE_NAME = b"oidc_session" COOKIE_PATH = "/_synapse/oidc" -class TestMappingProvider(OidcMappingProvider): +class TestMappingProvider: @staticmethod def parse_config(config): return + def __init__(self, config): + pass + def get_remote_user_id(self, userinfo): return userinfo["sub"] @@ -94,14 +85,12 @@ class TestMappingProviderExtra(TestMappingProvider): return {"phone": userinfo["phone"]} -def simple_async_mock(return_value=None, raises=None): - # AsyncMock is not available in python3.5, this mimics part of its behaviour - async def cb(*args, **kwargs): - if raises: - raise raises - return return_value - - return Mock(side_effect=cb) +class TestMappingProviderFailures(TestMappingProvider): + async def map_user_attributes(self, userinfo, token, failures): + return { + "localpart": userinfo["username"] + (str(failures) if failures else ""), + "display_name": None, + } async def get_json(url): @@ -124,22 +113,16 @@ async def get_json(url): class OidcHandlerTestCase(HomeserverTestCase): - def make_homeserver(self, reactor, clock): - - self.http_client = Mock(spec=["get_json"]) - self.http_client.get_json.side_effect = get_json - self.http_client.user_agent = "Synapse Test" - - config = self.default_config() + def default_config(self): + config = super().default_config() config["public_baseurl"] = BASE_URL - oidc_config = {} - oidc_config["enabled"] = True - oidc_config["client_id"] = CLIENT_ID - oidc_config["client_secret"] = CLIENT_SECRET - oidc_config["issuer"] = ISSUER - oidc_config["scopes"] = SCOPES - oidc_config["user_mapping_provider"] = { - "module": __name__ + ".TestMappingProvider", + oidc_config = { + "enabled": True, + "client_id": CLIENT_ID, + "client_secret": CLIENT_SECRET, + "issuer": ISSUER, + "scopes": SCOPES, + "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"}, } # Update this config with what's in the default config so that @@ -147,13 +130,24 @@ class OidcHandlerTestCase(HomeserverTestCase): oidc_config.update(config.get("oidc_config", 
{})) config["oidc_config"] = oidc_config - hs = self.setup_test_homeserver( - http_client=self.http_client, - proxied_http_client=self.http_client, - config=config, - ) + return config - self.handler = OidcHandler(hs) + def make_homeserver(self, reactor, clock): + + self.http_client = Mock(spec=["get_json"]) + self.http_client.get_json.side_effect = get_json + self.http_client.user_agent = "Synapse Test" + + hs = self.setup_test_homeserver(proxied_http_client=self.http_client) + + self.handler = hs.get_oidc_handler() + sso_handler = hs.get_sso_handler() + # Mock the render error method. + self.render_error = Mock(return_value=None) + sso_handler.render_error = self.render_error + + # Reduce the number of attempts when generating MXIDs. + sso_handler._MAP_USERNAME_RETRIES = 3 return hs @@ -161,12 +155,13 @@ class OidcHandlerTestCase(HomeserverTestCase): return patch.dict(self.handler._provider_metadata, values) def assertRenderedError(self, error, error_description=None): - args = self.handler._render_error.call_args[0] + args = self.render_error.call_args[0] self.assertEqual(args[1], error) if error_description is not None: self.assertEqual(args[2], error_description) # Reset the render_error mock - self.handler._render_error.reset_mock() + self.render_error.reset_mock() + return args def test_config(self): """Basic config correctly sets up the callback URL and client auth correctly.""" @@ -286,9 +281,15 @@ class OidcHandlerTestCase(HomeserverTestCase): h._validate_metadata, ) - # Tests for configs that the userinfo endpoint + # Tests for configs that require the userinfo endpoint self.assertFalse(h._uses_userinfo) - h._scopes = [] # do not request the openid scope + self.assertEqual(h._user_profile_method, "auto") + h._user_profile_method = "userinfo_endpoint" + self.assertTrue(h._uses_userinfo) + + # Revert the profile method and do not request the "openid" scope. + h._user_profile_method = "auto" + h._scopes = [] self.assertTrue(h._uses_userinfo) self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata) @@ -350,7 +351,6 @@ class OidcHandlerTestCase(HomeserverTestCase): def test_callback_error(self): """Errors from the provider returned in the callback are displayed.""" - self.handler._render_error = Mock() request = Mock(args={}) request.args[b"error"] = [b"invalid_client"] self.get_success(self.handler.handle_oidc_callback(request)) @@ -371,25 +371,29 @@ class OidcHandlerTestCase(HomeserverTestCase): - when the userinfo fetching fails - when the code exchange fails """ + + # ensure that we are correctly testing the fallback when "get_extra_attributes" + # is not implemented. 
+ mapping_provider = self.handler._user_mapping_provider + with self.assertRaises(AttributeError): + _ = mapping_provider.get_extra_attributes + token = { "type": "bearer", "id_token": "id_token", "access_token": "access_token", } + username = "bar" userinfo = { "sub": "foo", - "preferred_username": "bar", + "username": username, } - user_id = "@foo:domain.org" - self.handler._render_error = Mock(return_value=None) + expected_user_id = "@%s:%s" % (username, self.hs.hostname) self.handler._exchange_code = simple_async_mock(return_value=token) self.handler._parse_id_token = simple_async_mock(return_value=userinfo) self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) - self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) - self.handler._auth_handler.complete_sso_login = simple_async_mock() - request = Mock( - spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"] - ) + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() code = "code" state = "state" @@ -397,67 +401,56 @@ class OidcHandlerTestCase(HomeserverTestCase): client_redirect_url = "http://client/redirect" user_agent = "Browser" ip_address = "10.0.0.1" - request.getCookie.return_value = self.handler._generate_oidc_session_token( + session = self.handler._generate_oidc_session_token( state=state, nonce=nonce, client_redirect_url=client_redirect_url, ui_auth_session_id=None, ) - - request.args = {} - request.args[b"code"] = [code.encode("utf-8")] - request.args[b"state"] = [state.encode("utf-8")] - - request.requestHeaders = Mock(spec=["getRawHeaders"]) - request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")] - request.getClientIP.return_value = ip_address + request = _build_callback_request( + code, state, session, user_agent=user_agent, ip_address=ip_address + ) self.get_success(self.handler.handle_oidc_callback(request)) - self.handler._auth_handler.complete_sso_login.assert_called_once_with( - user_id, request, client_redirect_url, {}, + auth_handler.complete_sso_login.assert_called_once_with( + expected_user_id, request, client_redirect_url, None, ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.handler._map_userinfo_to_user.assert_called_once_with( - userinfo, token, user_agent, ip_address - ) self.handler._fetch_userinfo.assert_not_called() - self.handler._render_error.assert_not_called() + self.render_error.assert_not_called() # Handle mapping errors - self.handler._map_userinfo_to_user = simple_async_mock( - raises=MappingException() - ) - self.get_success(self.handler.handle_oidc_callback(request)) - self.assertRenderedError("mapping_error") - self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) + with patch.object( + self.handler, + "_remote_id_from_userinfo", + new=Mock(side_effect=MappingException()), + ): + self.get_success(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("mapping_error") # Handle ID token errors self.handler._parse_id_token = simple_async_mock(raises=Exception()) self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_token") - self.handler._auth_handler.complete_sso_login.reset_mock() + auth_handler.complete_sso_login.reset_mock() self.handler._exchange_code.reset_mock() self.handler._parse_id_token.reset_mock() - self.handler._map_userinfo_to_user.reset_mock() self.handler._fetch_userinfo.reset_mock() 
# With userinfo fetching self.handler._scopes = [] # do not ask the "openid" scope self.get_success(self.handler.handle_oidc_callback(request)) - self.handler._auth_handler.complete_sso_login.assert_called_once_with( - user_id, request, client_redirect_url, {}, + auth_handler.complete_sso_login.assert_called_once_with( + expected_user_id, request, client_redirect_url, None, ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_not_called() - self.handler._map_userinfo_to_user.assert_called_once_with( - userinfo, token, user_agent, ip_address - ) self.handler._fetch_userinfo.assert_called_once_with(token) - self.handler._render_error.assert_not_called() + self.render_error.assert_not_called() # Handle userinfo fetching error self.handler._fetch_userinfo = simple_async_mock(raises=Exception()) @@ -473,7 +466,6 @@ class OidcHandlerTestCase(HomeserverTestCase): def test_callback_session(self): """The callback verifies the session presence and validity""" - self.handler._render_error = Mock(return_value=None) request = Mock(spec=["args", "getCookie", "addCookie"]) # Missing cookie @@ -607,66 +599,55 @@ class OidcHandlerTestCase(HomeserverTestCase): } userinfo = { "sub": "foo", + "username": "foo", "phone": "1234567", } - user_id = "@foo:domain.org" self.handler._exchange_code = simple_async_mock(return_value=token) self.handler._parse_id_token = simple_async_mock(return_value=userinfo) - self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) - self.handler._auth_handler.complete_sso_login = simple_async_mock() - request = Mock( - spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"] - ) + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() state = "state" client_redirect_url = "http://client/redirect" - request.getCookie.return_value = self.handler._generate_oidc_session_token( + session = self.handler._generate_oidc_session_token( state=state, nonce="nonce", client_redirect_url=client_redirect_url, ui_auth_session_id=None, ) - - request.args = {} - request.args[b"code"] = [b"code"] - request.args[b"state"] = [state.encode("utf-8")] - - request.requestHeaders = Mock(spec=["getRawHeaders"]) - request.requestHeaders.getRawHeaders.return_value = [b"Browser"] - request.getClientIP.return_value = "10.0.0.1" + request = _build_callback_request("code", state, session) self.get_success(self.handler.handle_oidc_callback(request)) - self.handler._auth_handler.complete_sso_login.assert_called_once_with( - user_id, request, client_redirect_url, {"phone": "1234567"}, + auth_handler.complete_sso_login.assert_called_once_with( + "@foo:test", request, client_redirect_url, {"phone": "1234567"}, ) def test_map_userinfo_to_user(self): """Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + userinfo = { "sub": "test_user", "username": "test_user", } - # The token doesn't matter with the default user mapping provider. - token = {} - mxid = self.get_success( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", ANY, ANY, None, ) - self.assertEqual(mxid, "@test_user:test") + auth_handler.complete_sso_login.reset_mock() # Some providers return an integer ID. 
userinfo = { "sub": 1234, "username": "test_user_2", } - mxid = self.get_success( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user_2:test", ANY, ANY, None, ) - self.assertEqual(mxid, "@test_user_2:test") + auth_handler.complete_sso_login.reset_mock() # Test if the mxid is already taken store = self.hs.get_datastore() @@ -675,30 +656,352 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user3.to_string(), password_hash=None) ) userinfo = {"sub": "test3", "username": "test_user_3"} - e = self.get_failure( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ), - MappingException, + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_not_called() + self.assertRenderedError( + "mapping_error", + "Mapping provider does not support de-duplicating Matrix IDs", ) - self.assertEqual(str(e.value), "mxid '@test_user_3:test' is already taken") @override_config({"oidc_config": {"allow_existing_users": True}}) def test_map_userinfo_to_existing_user(self): """Existing users can log in with OpenID Connect when allow_existing_users is True.""" store = self.hs.get_datastore() - user4 = UserID.from_string("@test_user_4:test") + user = UserID.from_string("@test_user:test") + self.get_success( + store.register_user(user_id=user.to_string(), password_hash=None) + ) + + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # Map a user via SSO. + userinfo = { + "sub": "test", + "username": "test_user", + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_called_once_with( + user.to_string(), ANY, ANY, None, + ) + auth_handler.complete_sso_login.reset_mock() + + # Subsequent calls should map to the same mxid. + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_called_once_with( + user.to_string(), ANY, ANY, None, + ) + auth_handler.complete_sso_login.reset_mock() + + # Note that a second SSO user can be mapped to the same Matrix ID. (This + # requires a unique sub, but something that maps to the same matrix ID, + # in this case we'll just use the same username. A more realistic example + # would be subs which are email addresses, and mapping from the localpart + # of the email, e.g. bob@foo.com and bob@bar.com -> @bob:test.) + userinfo = { + "sub": "test1", + "username": "test_user", + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_called_once_with( + user.to_string(), ANY, ANY, None, + ) + auth_handler.complete_sso_login.reset_mock() + + # Register some non-exact matching cases. + user2 = UserID.from_string("@TEST_user_2:test") + self.get_success( + store.register_user(user_id=user2.to_string(), password_hash=None) + ) + user2_caps = UserID.from_string("@test_USER_2:test") + self.get_success( + store.register_user(user_id=user2_caps.to_string(), password_hash=None) + ) + + # Attempting to login without matching a name exactly is an error. 
+ userinfo = { + "sub": "test2", + "username": "TEST_USER_2", + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_not_called() + args = self.assertRenderedError("mapping_error") + self.assertTrue( + args[2].startswith( + "Attempted to login as '@TEST_USER_2:test' but it matches more than one user inexactly:" + ) + ) + + # Logging in when matching a name exactly should work. + user2 = UserID.from_string("@TEST_USER_2:test") + self.get_success( + store.register_user(user_id=user2.to_string(), password_hash=None) + ) + + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_called_once_with( + "@TEST_USER_2:test", ANY, ANY, None, + ) + + def test_map_userinfo_to_invalid_localpart(self): + """If the mapping provider generates an invalid localpart it should be rejected.""" + self.get_success( + _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"}) + ) + self.assertRenderedError("mapping_error", "localpart is invalid: föö") + + @override_config( + { + "oidc_config": { + "user_mapping_provider": { + "module": __name__ + ".TestMappingProviderFailures" + } + } + } + ) + def test_map_userinfo_to_user_retries(self): + """The mapping provider can retry generating an MXID if the MXID is already in use.""" + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + store = self.hs.get_datastore() self.get_success( - store.register_user(user_id=user4.to_string(), password_hash=None) + store.register_user(user_id="@test_user:test", password_hash=None) ) userinfo = { - "sub": "test4", - "username": "test_user_4", + "sub": "test", + "username": "test_user", } - token = {} - mxid = self.get_success( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + + # test_user is already taken, so test_user1 gets registered instead. + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user1:test", ANY, ANY, None, + ) + auth_handler.complete_sso_login.reset_mock() + + # Register all of the potential mxids for a particular OIDC username. + self.get_success( + store.register_user(user_id="@tester:test", password_hash=None) + ) + for i in range(1, 3): + self.get_success( + store.register_user(user_id="@tester%d:test" % i, password_hash=None) ) + + # Now attempt to map to a username, this will fail since all potential usernames are taken. 
+ userinfo = { + "sub": "tester", + "username": "tester", + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_not_called() + self.assertRenderedError( + "mapping_error", "Unable to generate a Matrix ID from the SSO response" + ) + + def test_empty_localpart(self): + """Attempts to map onto an empty localpart should be rejected.""" + userinfo = { + "sub": "tester", + "username": "", + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + self.assertRenderedError("mapping_error", "localpart is invalid: ") + + @override_config( + { + "oidc_config": { + "user_mapping_provider": { + "config": {"localpart_template": "{{ user.username }}"} + } + } + } + ) + def test_null_localpart(self): + """Mapping onto a null localpart via an empty OIDC attribute should be rejected""" + userinfo = { + "sub": "tester", + "username": None, + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + self.assertRenderedError("mapping_error", "localpart is invalid: ") + + +class UsernamePickerTestCase(HomeserverTestCase): + servlets = [login.register_servlets] + + def default_config(self): + config = super().default_config() + config["public_baseurl"] = BASE_URL + oidc_config = { + "enabled": True, + "client_id": CLIENT_ID, + "client_secret": CLIENT_SECRET, + "issuer": ISSUER, + "scopes": SCOPES, + "user_mapping_provider": { + "config": {"display_name_template": "{{ user.displayname }}"} + }, + } + + # Update this config with what's in the default config so that + # override_config works as expected. + oidc_config.update(config.get("oidc_config", {})) + config["oidc_config"] = oidc_config + + # whitelist this client URI so we redirect straight to it rather than + # serving a confirmation page + config["sso"] = {"client_whitelist": ["https://whitelisted.client"]} + return config + + def create_resource_dict(self) -> Dict[str, Resource]: + d = super().create_resource_dict() + d["/_synapse/client/pick_username"] = pick_username_resource(self.hs) + return d + + def test_username_picker(self): + """Test the happy path of a username picker flow.""" + client_redirect_url = "https://whitelisted.client" + + # first of all, mock up an OIDC callback to the OidcHandler, which should + # raise a RedirectException + userinfo = {"sub": "tester", "displayname": "Jonny"} + f = self.get_failure( + _make_callback_with_userinfo( + self.hs, userinfo, client_redirect_url=client_redirect_url + ), + RedirectException, + ) + + # check the Location and cookies returned by the RedirectException + self.assertEqual(f.value.location, b"/_synapse/client/pick_username") + cookieheader = f.value.cookies[0] + regex = re.compile(b"^username_mapping_session=([a-zA-Z]+);") + m = regex.search(cookieheader) + if not m: + self.fail("cookie header %s does not match %s" % (cookieheader, regex)) + + # introspect the sso handler a bit to check that the username mapping session + # looks ok. 
+ session_id = m.group(1).decode("ascii") + username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions + self.assertIn( + session_id, username_mapping_sessions, "session id not found in map" ) - self.assertEqual(mxid, "@test_user_4:test") + session = username_mapping_sessions[session_id] + self.assertEqual(session.remote_user_id, "tester") + self.assertEqual(session.display_name, "Jonny") + self.assertEqual(session.client_redirect_url, client_redirect_url) + + # the expiry time should be about 15 minutes away + expected_expiry = self.clock.time_msec() + (15 * 60 * 1000) + self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000) + + # Now, submit a username to the username picker, which should serve a redirect + # back to the client + submit_path = f.value.location + b"/submit" + content = urlencode({b"username": b"bobby"}).encode("utf8") + chan = self.make_request( + "POST", + path=submit_path, + content=content, + content_is_form=True, + custom_headers=[ + ("Cookie", cookieheader), + # old versions of twisted don't do form-parsing without a valid + # content-length header. + ("Content-Length", str(len(content))), + ], + ) + self.assertEqual(chan.code, 302, chan.result) + location_headers = chan.headers.getRawHeaders("Location") + # ensure that the returned location starts with the requested redirect URL + self.assertEqual( + location_headers[0][: len(client_redirect_url)], client_redirect_url + ) + + # fish the login token out of the returned redirect uri + parts = urlparse(location_headers[0]) + query = parse_qs(parts.query) + login_token = query["loginToken"][0] + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token, mxid, and device id. + chan = self.make_request( + "POST", "/login", content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(chan.code, 200, chan.result) + self.assertEqual(chan.json_body["user_id"], "@bobby:test") + + +async def _make_callback_with_userinfo( + hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect" +) -> None: + """Mock up an OIDC callback with the given userinfo dict + + We'll pull out the OIDC handler from the homeserver, stub out a couple of methods, + and poke in the userinfo dict as if it were the response to an OIDC userinfo call. + + Args: + hs: the HomeServer impl to send the callback to. + userinfo: the OIDC userinfo dict + client_redirect_url: the URL to redirect to on success. + """ + handler = hs.get_oidc_handler() + handler._exchange_code = simple_async_mock(return_value={}) + handler._parse_id_token = simple_async_mock(return_value=userinfo) + handler._fetch_userinfo = simple_async_mock(return_value=userinfo) + + state = "state" + session = handler._generate_oidc_session_token( + state=state, + nonce="nonce", + client_redirect_url=client_redirect_url, + ui_auth_session_id=None, + ) + request = _build_callback_request("code", state, session) + + await handler.handle_oidc_callback(request) + + +def _build_callback_request( + code: str, + state: str, + session: str, + user_agent: str = "Browser", + ip_address: str = "10.0.0.1", +): + """Builds a fake SynapseRequest to mock the browser callback + + Returns a Mock object which looks like the SynapseRequest we get from a browser + after SSO (before we return to the client) + + Args: + code: the authorization code which would have been returned by the OIDC + provider + state: the "state" param which would have been passed around in the + query param. 
Should be the same as was embedded in the session in + _build_oidc_session. + session: the "session" which would have been passed around in the cookie. + user_agent: the user-agent to present + ip_address: the IP address to pretend the request came from + """ + request = Mock( + spec=[ + "args", + "getCookie", + "addCookie", + "requestHeaders", + "getClientIP", + "get_user_agent", + ] + ) + + request.getCookie.return_value = session + request.args = {} + request.args[b"code"] = [code.encode("utf-8")] + request.args[b"state"] = [state.encode("utf-8")] + request.getClientIP.return_value = ip_address + request.get_user_agent.return_value = user_agent + return request diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py new file mode 100644
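Most of the rewritten OIDC tests funnel through the two helpers defined above: _make_callback_with_userinfo stubs the code exchange and ID-token parsing, and _build_callback_request fakes the browser callback. A short usage sketch modelled on test_map_userinfo_to_user; the test name and userinfo values here are illustrative, not part of the diff.

# Sketch of a test body inside OidcHandlerTestCase (illustrative only):
def test_userinfo_maps_to_mxid(self):
    auth_handler = self.hs.get_auth_handler()
    auth_handler.complete_sso_login = simple_async_mock()

    # Pretend the provider returned this userinfo for the exchanged code.
    userinfo = {"sub": "alice-remote-id", "username": "alice"}
    self.get_success(_make_callback_with_userinfo(self.hs, userinfo))

    # The test mapping provider turns "username" into the localpart, and no
    # extra attributes are passed through, hence the trailing None.
    auth_handler.complete_sso_login.assert_called_once_with(
        "@alice:test", ANY, ANY, None,
    )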
index 0000000000..f816594ee4 --- /dev/null +++ b/tests/handlers/test_password_providers.py
@@ -0,0 +1,603 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the password_auth_provider interface""" + +from typing import Any, Type, Union + +from mock import Mock + +from twisted.internet import defer + +import synapse +from synapse.rest.client.v1 import login +from synapse.rest.client.v2_alpha import devices +from synapse.types import JsonDict + +from tests import unittest +from tests.server import FakeChannel +from tests.unittest import override_config + +# (possibly experimental) login flows we expect to appear in the list after the normal +# ones +ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}] + +# a mock instance which the dummy auth providers delegate to, so we can see what's going +# on +mock_password_provider = Mock() + + +class PasswordOnlyAuthProvider: + """A password_provider which only implements `check_password`.""" + + @staticmethod + def parse_config(self): + pass + + def __init__(self, config, account_handler): + pass + + def check_password(self, *args): + return mock_password_provider.check_password(*args) + + +class CustomAuthProvider: + """A password_provider which implements a custom login type.""" + + @staticmethod + def parse_config(self): + pass + + def __init__(self, config, account_handler): + pass + + def get_supported_login_types(self): + return {"test.login_type": ["test_field"]} + + def check_auth(self, *args): + return mock_password_provider.check_auth(*args) + + +class PasswordCustomAuthProvider: + """A password_provider which implements password login via `check_auth`, as well + as a custom type.""" + + @staticmethod + def parse_config(self): + pass + + def __init__(self, config, account_handler): + pass + + def get_supported_login_types(self): + return {"m.login.password": ["password"], "test.login_type": ["test_field"]} + + def check_auth(self, *args): + return mock_password_provider.check_auth(*args) + + +def providers_config(*providers: Type[Any]) -> dict: + """Returns a config dict that will enable the given password auth providers""" + return { + "password_providers": [ + {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}} + for provider in providers + ] + } + + +class PasswordAuthProviderTests(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + devices.register_servlets, + ] + + def setUp(self): + # we use a global mock device, so make sure we are starting with a clean slate + mock_password_provider.reset_mock() + super().setUp() + + @override_config(providers_config(PasswordOnlyAuthProvider)) + def test_password_only_auth_provider_login(self): + # login flows should only have m.login.password + flows = self._get_login_flows() + self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) + + # check_password must return an awaitable + mock_password_provider.check_password.return_value = defer.succeed(True) + channel = 
self._send_password_login("u", "p") + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual("@u:test", channel.json_body["user_id"]) + mock_password_provider.check_password.assert_called_once_with("@u:test", "p") + mock_password_provider.reset_mock() + + # login with mxid should work too + channel = self._send_password_login("@u:bz", "p") + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual("@u:bz", channel.json_body["user_id"]) + mock_password_provider.check_password.assert_called_once_with("@u:bz", "p") + mock_password_provider.reset_mock() + + # try a weird username / pass. Honestly it's unclear what we *expect* to happen + # in these cases, but at least we can guard against the API changing + # unexpectedly + channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ") + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"]) + mock_password_provider.check_password.assert_called_once_with( + "@ USER🙂NAME :test", " pASS😢word " + ) + + @override_config(providers_config(PasswordOnlyAuthProvider)) + def test_password_only_auth_provider_ui_auth(self): + """UI Auth should delegate correctly to the password provider""" + + # create the user, otherwise access doesn't work + module_api = self.hs.get_module_api() + self.get_success(module_api.register_user("u")) + + # log in twice, to get two devices + mock_password_provider.check_password.return_value = defer.succeed(True) + tok1 = self.login("u", "p") + self.login("u", "p", device_id="dev2") + mock_password_provider.reset_mock() + + # have the auth provider deny the request to start with + mock_password_provider.check_password.return_value = defer.succeed(False) + + # make the initial request which returns a 401 + session = self._start_delete_device_session(tok1, "dev2") + mock_password_provider.check_password.assert_not_called() + + # Make another request providing the UI auth flow. + channel = self._authed_delete_device(tok1, "dev2", session, "u", "p") + self.assertEqual(channel.code, 401) # XXX why not a 403? 
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + mock_password_provider.check_password.assert_called_once_with("@u:test", "p") + mock_password_provider.reset_mock() + + # Finally, check the request goes through when we allow it + mock_password_provider.check_password.return_value = defer.succeed(True) + channel = self._authed_delete_device(tok1, "dev2", session, "u", "p") + self.assertEqual(channel.code, 200) + mock_password_provider.check_password.assert_called_once_with("@u:test", "p") + + @override_config(providers_config(PasswordOnlyAuthProvider)) + def test_local_user_fallback_login(self): + """rejected login should fall back to local db""" + self.register_user("localuser", "localpass") + + # check_password must return an awaitable + mock_password_provider.check_password.return_value = defer.succeed(False) + channel = self._send_password_login("u", "p") + self.assertEqual(channel.code, 403, channel.result) + + channel = self._send_password_login("localuser", "localpass") + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual("@localuser:test", channel.json_body["user_id"]) + + @override_config(providers_config(PasswordOnlyAuthProvider)) + def test_local_user_fallback_ui_auth(self): + """rejected login should fall back to local db""" + self.register_user("localuser", "localpass") + + # have the auth provider deny the request + mock_password_provider.check_password.return_value = defer.succeed(False) + + # log in twice, to get two devices + tok1 = self.login("localuser", "localpass") + self.login("localuser", "localpass", device_id="dev2") + mock_password_provider.check_password.reset_mock() + + # first delete should give a 401 + session = self._start_delete_device_session(tok1, "dev2") + mock_password_provider.check_password.assert_not_called() + + # Wrong password + channel = self._authed_delete_device(tok1, "dev2", session, "localuser", "xxx") + self.assertEqual(channel.code, 401) # XXX why not a 403? 
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + mock_password_provider.check_password.assert_called_once_with( + "@localuser:test", "xxx" + ) + mock_password_provider.reset_mock() + + # Right password + channel = self._authed_delete_device( + tok1, "dev2", session, "localuser", "localpass" + ) + self.assertEqual(channel.code, 200) + mock_password_provider.check_password.assert_called_once_with( + "@localuser:test", "localpass" + ) + + @override_config( + { + **providers_config(PasswordOnlyAuthProvider), + "password_config": {"localdb_enabled": False}, + } + ) + def test_no_local_user_fallback_login(self): + """localdb_enabled can block login with the local password + """ + self.register_user("localuser", "localpass") + + # check_password must return an awaitable + mock_password_provider.check_password.return_value = defer.succeed(False) + channel = self._send_password_login("localuser", "localpass") + self.assertEqual(channel.code, 403) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + mock_password_provider.check_password.assert_called_once_with( + "@localuser:test", "localpass" + ) + + @override_config( + { + **providers_config(PasswordOnlyAuthProvider), + "password_config": {"localdb_enabled": False}, + } + ) + def test_no_local_user_fallback_ui_auth(self): + """localdb_enabled can block ui auth with the local password + """ + self.register_user("localuser", "localpass") + + # allow login via the auth provider + mock_password_provider.check_password.return_value = defer.succeed(True) + + # log in twice, to get two devices + tok1 = self.login("localuser", "p") + self.login("localuser", "p", device_id="dev2") + mock_password_provider.check_password.reset_mock() + + # first delete should give a 401 + channel = self._delete_device(tok1, "dev2") + self.assertEqual(channel.code, 401) + # m.login.password UIA is permitted because the auth provider allows it, + # even though the localdb does not. + self.assertEqual(channel.json_body["flows"], [{"stages": ["m.login.password"]}]) + session = channel.json_body["session"] + mock_password_provider.check_password.assert_not_called() + + # now try deleting with the local password + mock_password_provider.check_password.return_value = defer.succeed(False) + channel = self._authed_delete_device( + tok1, "dev2", session, "localuser", "localpass" + ) + self.assertEqual(channel.code, 401) # XXX why not a 403? + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + mock_password_provider.check_password.assert_called_once_with( + "@localuser:test", "localpass" + ) + + @override_config( + { + **providers_config(PasswordOnlyAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_password_auth_disabled(self): + """password auth doesn't work if it's disabled across the board""" + # login flows should be empty + flows = self._get_login_flows() + self.assertEqual(flows, ADDITIONAL_LOGIN_FLOWS) + + # login shouldn't work and should be rejected with a 400 ("unknown login type") + channel = self._send_password_login("u", "p") + self.assertEqual(channel.code, 400, channel.result) + mock_password_provider.check_password.assert_not_called() + + @override_config(providers_config(CustomAuthProvider)) + def test_custom_auth_provider_login(self): + # login flows should have the custom flow and m.login.password, since we + # haven't disabled local password lookup. 
+ # (password must come first, because reasons) + flows = self._get_login_flows() + self.assertEqual( + flows, + [{"type": "m.login.password"}, {"type": "test.login_type"}] + + ADDITIONAL_LOGIN_FLOWS, + ) + + # login with missing param should be rejected + channel = self._send_login("test.login_type", "u") + self.assertEqual(channel.code, 400, channel.result) + mock_password_provider.check_auth.assert_not_called() + + mock_password_provider.check_auth.return_value = defer.succeed("@user:bz") + channel = self._send_login("test.login_type", "u", test_field="y") + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual("@user:bz", channel.json_body["user_id"]) + mock_password_provider.check_auth.assert_called_once_with( + "u", "test.login_type", {"test_field": "y"} + ) + mock_password_provider.reset_mock() + + # try a weird username. Again, it's unclear what we *expect* to happen + # in these cases, but at least we can guard against the API changing + # unexpectedly + mock_password_provider.check_auth.return_value = defer.succeed( + "@ MALFORMED! :bz" + ) + channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ") + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"]) + mock_password_provider.check_auth.assert_called_once_with( + " USER🙂NAME ", "test.login_type", {"test_field": " abc "} + ) + + @override_config(providers_config(CustomAuthProvider)) + def test_custom_auth_provider_ui_auth(self): + # register the user and log in twice, to get two devices + self.register_user("localuser", "localpass") + tok1 = self.login("localuser", "localpass") + self.login("localuser", "localpass", device_id="dev2") + + # make the initial request which returns a 401 + channel = self._delete_device(tok1, "dev2") + self.assertEqual(channel.code, 401) + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"]) + session = channel.json_body["session"] + + # missing param + body = { + "auth": { + "type": "test.login_type", + "identifier": {"type": "m.id.user", "user": "localuser"}, + "session": session, + }, + } + + channel = self._delete_device(tok1, "dev2", body) + self.assertEqual(channel.code, 400) + # there's a perfectly good M_MISSING_PARAM errcode, but heaven forfend we should + # use it... 
+ self.assertIn("Missing parameters", channel.json_body["error"]) + mock_password_provider.check_auth.assert_not_called() + mock_password_provider.reset_mock() + + # right params, but authing as the wrong user + mock_password_provider.check_auth.return_value = defer.succeed("@user:bz") + body["auth"]["test_field"] = "foo" + channel = self._delete_device(tok1, "dev2", body) + self.assertEqual(channel.code, 403) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + mock_password_provider.check_auth.assert_called_once_with( + "localuser", "test.login_type", {"test_field": "foo"} + ) + mock_password_provider.reset_mock() + + # and finally, succeed + mock_password_provider.check_auth.return_value = defer.succeed( + "@localuser:test" + ) + channel = self._delete_device(tok1, "dev2", body) + self.assertEqual(channel.code, 200) + mock_password_provider.check_auth.assert_called_once_with( + "localuser", "test.login_type", {"test_field": "foo"} + ) + + @override_config(providers_config(CustomAuthProvider)) + def test_custom_auth_provider_callback(self): + callback = Mock(return_value=defer.succeed(None)) + + mock_password_provider.check_auth.return_value = defer.succeed( + ("@user:bz", callback) + ) + channel = self._send_login("test.login_type", "u", test_field="y") + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual("@user:bz", channel.json_body["user_id"]) + mock_password_provider.check_auth.assert_called_once_with( + "u", "test.login_type", {"test_field": "y"} + ) + + # check the args to the callback + callback.assert_called_once() + call_args, call_kwargs = callback.call_args + # should be one positional arg + self.assertEqual(len(call_args), 1) + self.assertEqual(call_args[0]["user_id"], "@user:bz") + for p in ["user_id", "access_token", "device_id", "home_server"]: + self.assertIn(p, call_args[0]) + + @override_config( + {**providers_config(CustomAuthProvider), "password_config": {"enabled": False}} + ) + def test_custom_auth_password_disabled(self): + """Test login with a custom auth provider where password login is disabled""" + self.register_user("localuser", "localpass") + + flows = self._get_login_flows() + self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS) + + # login shouldn't work and should be rejected with a 400 ("unknown login type") + channel = self._send_password_login("localuser", "localpass") + self.assertEqual(channel.code, 400, channel.result) + mock_password_provider.check_auth.assert_not_called() + + @override_config( + { + **providers_config(CustomAuthProvider), + "password_config": {"enabled": False, "localdb_enabled": False}, + } + ) + def test_custom_auth_password_disabled_localdb_enabled(self): + """Check the localdb_enabled == enabled == False + + Regression test for https://github.com/matrix-org/synapse/issues/8914: check + that setting *both* `localdb_enabled` *and* `password: enabled` to False doesn't + cause an exception. 
+ """ + self.register_user("localuser", "localpass") + + flows = self._get_login_flows() + self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS) + + # login shouldn't work and should be rejected with a 400 ("unknown login type") + channel = self._send_password_login("localuser", "localpass") + self.assertEqual(channel.code, 400, channel.result) + mock_password_provider.check_auth.assert_not_called() + + @override_config( + { + **providers_config(PasswordCustomAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_password_custom_auth_password_disabled_login(self): + """log in with a custom auth provider which implements password, but password + login is disabled""" + self.register_user("localuser", "localpass") + + flows = self._get_login_flows() + self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS) + + # login shouldn't work and should be rejected with a 400 ("unknown login type") + channel = self._send_password_login("localuser", "localpass") + self.assertEqual(channel.code, 400, channel.result) + mock_password_provider.check_auth.assert_not_called() + + @override_config( + { + **providers_config(PasswordCustomAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_password_custom_auth_password_disabled_ui_auth(self): + """UI Auth with a custom auth provider which implements password, but password + login is disabled""" + # register the user and log in twice via the test login type to get two devices, + self.register_user("localuser", "localpass") + mock_password_provider.check_auth.return_value = defer.succeed( + "@localuser:test" + ) + channel = self._send_login("test.login_type", "localuser", test_field="") + self.assertEqual(channel.code, 200, channel.result) + tok1 = channel.json_body["access_token"] + + channel = self._send_login( + "test.login_type", "localuser", test_field="", device_id="dev2" + ) + self.assertEqual(channel.code, 200, channel.result) + + # make the initial request which returns a 401 + channel = self._delete_device(tok1, "dev2") + self.assertEqual(channel.code, 401) + # Ensure that flows are what is expected. In particular, "password" should *not* + # be present. 
+ self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"]) + session = channel.json_body["session"] + + mock_password_provider.reset_mock() + + # check that auth with password is rejected + body = { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": "localuser"}, + "password": "localpass", + "session": session, + }, + } + + channel = self._delete_device(tok1, "dev2", body) + self.assertEqual(channel.code, 400) + self.assertEqual( + "Password login has been disabled.", channel.json_body["error"] + ) + mock_password_provider.check_auth.assert_not_called() + mock_password_provider.reset_mock() + + # successful auth + body["auth"]["type"] = "test.login_type" + body["auth"]["test_field"] = "x" + channel = self._delete_device(tok1, "dev2", body) + self.assertEqual(channel.code, 200) + mock_password_provider.check_auth.assert_called_once_with( + "localuser", "test.login_type", {"test_field": "x"} + ) + + @override_config( + { + **providers_config(CustomAuthProvider), + "password_config": {"localdb_enabled": False}, + } + ) + def test_custom_auth_no_local_user_fallback(self): + """Test login with a custom auth provider where the local db is disabled""" + self.register_user("localuser", "localpass") + + flows = self._get_login_flows() + self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS) + + # password login shouldn't work and should be rejected with a 400 + # ("unknown login type") + channel = self._send_password_login("localuser", "localpass") + self.assertEqual(channel.code, 400, channel.result) + + def _get_login_flows(self) -> JsonDict: + channel = self.make_request("GET", "/_matrix/client/r0/login") + self.assertEqual(channel.code, 200, channel.result) + return channel.json_body["flows"] + + def _send_password_login(self, user: str, password: str) -> FakeChannel: + return self._send_login(type="m.login.password", user=user, password=password) + + def _send_login(self, type, user, **params) -> FakeChannel: + params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type}) + channel = self.make_request("POST", "/_matrix/client/r0/login", params) + return channel + + def _start_delete_device_session(self, access_token, device_id) -> str: + """Make an initial delete device request, and return the UI Auth session ID""" + channel = self._delete_device(access_token, device_id) + self.assertEqual(channel.code, 401) + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + return channel.json_body["session"] + + def _authed_delete_device( + self, + access_token: str, + device_id: str, + session: str, + user_id: str, + password: str, + ) -> FakeChannel: + """Make a delete device request, authenticating with the given uid/password""" + return self._delete_device( + access_token, + device_id, + { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": user_id}, + "password": password, + "session": session, + }, + }, + ) + + def _delete_device( + self, access_token: str, device: str, body: Union[JsonDict, bytes] = b"", + ) -> FakeChannel: + """Delete an individual device.""" + channel = self.make_request( + "DELETE", "devices/" + device, body, access_token=access_token + ) + return channel diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 306dcfe944..0794b32c9c 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -463,14 +463,14 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
 
     def make_homeserver(self, reactor, clock):
         hs = self.setup_test_homeserver(
-            "server", http_client=None, federation_sender=Mock()
+            "server", federation_http_client=None, federation_sender=Mock()
         )
         return hs
 
     def prepare(self, reactor, clock, hs):
         self.federation_sender = hs.get_federation_sender()
         self.event_builder_factory = hs.get_event_builder_factory()
-        self.federation_handler = hs.get_handlers().federation_handler
+        self.federation_handler = hs.get_federation_handler()
         self.presence_handler = hs.get_presence_handler()
 
         # self.event_builder_for_2 = EventBuilderFactory(hs)
@@ -615,7 +615,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
             self.store.get_latest_event_ids_in_room(room_id)
         )
 
-        event = self.get_success(builder.build(prev_event_ids))
+        event = self.get_success(builder.build(prev_event_ids, None))
 
         self.get_success(self.federation_handler.on_receive_pdu(hostname, event))
 
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 8e95e53d9e..919547556b 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -20,7 +20,6 @@ from twisted.internet import defer
 
 import synapse.types
 from synapse.api.errors import AuthError, SynapseError
-from synapse.handlers.profile import MasterProfileHandler
 from synapse.types import UserID
 
 from tests import unittest
@@ -28,11 +27,6 @@ from tests.test_utils import make_awaitable
 from tests.utils import setup_test_homeserver
 
 
-class ProfileHandlers:
-    def __init__(self, hs):
-        self.profile_handler = MasterProfileHandler(hs)
-
-
 class ProfileTestCase(unittest.TestCase):
     """ Tests profile management. """
 
@@ -50,9 +44,6 @@ class ProfileTestCase(unittest.TestCase):
 
         hs = yield setup_test_homeserver(
             self.addCleanup,
-            http_client=None,
-            handlers=None,
-            resource_for_federation=Mock(),
             federation_client=self.mock_federation,
             federation_server=Mock(),
             federation_registry=self.mock_registry,
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index cb7c0ed51a..bdf3d0a8a2 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -18,7 +18,6 @@ from mock import Mock from synapse.api.auth import Auth from synapse.api.constants import UserTypes from synapse.api.errors import Codes, ResourceLimitError, SynapseError -from synapse.handlers.register import RegistrationHandler from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import RoomAlias, UserID, create_requester @@ -29,11 +28,6 @@ from tests.utils import mock_getRawHeaders from .. import unittest -class RegistrationHandlers: - def __init__(self, hs): - self.registration_handler = RegistrationHandler(hs) - - class RegistrationTestCase(unittest.HomeserverTestCase): """ Tests the RegistrationHandler. """ @@ -154,7 +148,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias_str = "#room:test" user_id = self.get_success(self.handler.register_user(localpart="jeff")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) - directory_handler = self.hs.get_handlers().directory_handler + directory_handler = self.hs.get_directory_handler() room_alias = RoomAlias.from_string(room_alias_str) room_id = self.get_success(directory_handler.get_association(room_alias)) @@ -193,7 +187,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): user_id = self.get_success(self.handler.register_user(localpart="support")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) - directory_handler = self.hs.get_handlers().directory_handler + directory_handler = self.hs.get_directory_handler() room_alias = RoomAlias.from_string(room_alias_str) self.get_failure(directory_handler.get_association(room_alias), SynapseError) @@ -205,7 +199,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.store.is_real_user = Mock(return_value=make_awaitable(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) - directory_handler = self.hs.get_handlers().directory_handler + directory_handler = self.hs.get_directory_handler() room_alias = RoomAlias.from_string(room_alias_str) room_id = self.get_success(directory_handler.get_association(room_alias)) @@ -237,7 +231,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): user_id = self.get_success(self.handler.register_user(localpart="jeff")) # Ensure the room was created. - directory_handler = self.hs.get_handlers().directory_handler + directory_handler = self.hs.get_directory_handler() room_alias = RoomAlias.from_string(room_alias_str) room_id = self.get_success(directory_handler.get_association(room_alias)) @@ -266,7 +260,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): user_id = self.get_success(self.handler.register_user(localpart="jeff")) # Ensure the room was created. - directory_handler = self.hs.get_handlers().directory_handler + directory_handler = self.hs.get_directory_handler() room_alias = RoomAlias.from_string(room_alias_str) room_id = self.get_success(directory_handler.get_association(room_alias)) @@ -304,7 +298,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): user_id = self.get_success(self.handler.register_user(localpart="jeff")) # Ensure the room was created. - directory_handler = self.hs.get_handlers().directory_handler + directory_handler = self.hs.get_directory_handler() room_alias = RoomAlias.from_string(room_alias_str) room_id = self.get_success(directory_handler.get_association(room_alias)) @@ -347,7 +341,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) # Ensure the room was created. 
- directory_handler = self.hs.get_handlers().directory_handler + directory_handler = self.hs.get_directory_handler() room_alias = RoomAlias.from_string(room_alias_str) room_id = self.get_success(directory_handler.get_association(room_alias)) @@ -384,7 +378,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): user_id = self.get_success(self.handler.register_user(localpart="jeff")) # Ensure the room was created. - directory_handler = self.hs.get_handlers().directory_handler + directory_handler = self.hs.get_directory_handler() room_alias = RoomAlias.from_string(room_alias_str) room_id = self.get_success(directory_handler.get_association(room_alias)) @@ -413,7 +407,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - event_creation_handler.send_nonmember_event(requester, event, context) + event_creation_handler.handle_new_client_event(requester, event, context) ) # Register a second user, which won't be be in the room (or even have an invite) diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py new file mode 100644
index 0000000000..548038214b
--- /dev/null
+++ b/tests/handlers/test_saml.py
@@ -0,0 +1,265 @@ +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from mock import Mock + +import attr + +from synapse.api.errors import RedirectException + +from tests.test_utils import simple_async_mock +from tests.unittest import HomeserverTestCase, override_config + +# Check if we have the dependencies to run the tests. +try: + import saml2.config + from saml2.sigver import SigverError + + has_saml2 = True + + # pysaml2 can be installed and imported, but might not be able to find xmlsec1. + config = saml2.config.SPConfig() + try: + config.load({"metadata": {}}) + has_xmlsec1 = True + except SigverError: + has_xmlsec1 = False +except ImportError: + has_saml2 = False + has_xmlsec1 = False + +# These are a few constants that are used as config parameters in the tests. +BASE_URL = "https://synapse/" + + +@attr.s +class FakeAuthnResponse: + ava = attr.ib(type=dict) + assertions = attr.ib(type=list, factory=list) + in_response_to = attr.ib(type=Optional[str], default=None) + + +class TestMappingProvider: + def __init__(self, config, module): + pass + + @staticmethod + def parse_config(config): + return + + @staticmethod + def get_saml_attributes(config): + return {"uid"}, {"displayName"} + + def get_remote_user_id(self, saml_response, client_redirect_url): + return saml_response.ava["uid"] + + def saml_response_to_user_attributes( + self, saml_response, failures, client_redirect_url + ): + localpart = saml_response.ava["username"] + (str(failures) if failures else "") + return {"mxid_localpart": localpart, "displayname": None} + + +class TestRedirectMappingProvider(TestMappingProvider): + def saml_response_to_user_attributes( + self, saml_response, failures, client_redirect_url + ): + raise RedirectException(b"https://custom-saml-redirect/") + + +class SamlHandlerTestCase(HomeserverTestCase): + def default_config(self): + config = super().default_config() + config["public_baseurl"] = BASE_URL + saml_config = { + "sp_config": {"metadata": {}}, + # Disable grandfathering. + "grandfathered_mxid_source_attribute": None, + "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"}, + } + + # Update this config with what's in the default config so that + # override_config works as expected. + saml_config.update(config.get("saml2_config", {})) + config["saml2_config"] = saml_config + + return config + + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + + self.handler = hs.get_saml_handler() + + # Reduce the number of attempts when generating MXIDs. 
+ sso_handler = hs.get_sso_handler() + sso_handler._MAP_USERNAME_RETRIES = 3 + + return hs + + if not has_saml2: + skip = "Requires pysaml2" + elif not has_xmlsec1: + skip = "Requires xmlsec1" + + def test_map_saml_response_to_user(self): + """Ensure that mapping the SAML response returned from a provider to an MXID works properly.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # send a mocked-up SAML response to the callback + saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "redirect_uri") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "redirect_uri", None + ) + + @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) + def test_map_saml_response_to_existing_user(self): + """Existing users can log in with SAML account.""" + store = self.hs.get_datastore() + self.get_success( + store.register_user(user_id="@test_user:test", password_hash=None) + ) + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # Map a user via SSO. + saml_response = FakeAuthnResponse( + {"uid": "tester", "mxid": ["test_user"], "username": "test_user"} + ) + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "", None + ) + + # Subsequent calls should map to the same mxid. 
+ auth_handler.complete_sso_login.reset_mock() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "") + ) + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "", None + ) + + def test_map_saml_response_to_invalid_localpart(self): + """If the mapping provider generates an invalid localpart it should be rejected.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # mock out the error renderer too + sso_handler = self.hs.get_sso_handler() + sso_handler.render_error = Mock(return_value=None) + + saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"}) + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, ""), + ) + sso_handler.render_error.assert_called_once_with( + request, "mapping_error", "localpart is invalid: föö" + ) + auth_handler.complete_sso_login.assert_not_called() + + def test_map_saml_response_to_user_retries(self): + """The mapping provider can retry generating an MXID if the MXID is already in use.""" + + # stub out the auth handler and error renderer + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + sso_handler = self.hs.get_sso_handler() + sso_handler.render_error = Mock(return_value=None) + + # register a user to occupy the first-choice MXID + store = self.hs.get_datastore() + self.get_success( + store.register_user(user_id="@test_user:test", password_hash=None) + ) + + # send the fake SAML response + saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"}) + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, ""), + ) + + # test_user is already taken, so test_user1 gets registered instead. + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user1:test", request, "", None + ) + auth_handler.complete_sso_login.reset_mock() + + # Register all of the potential mxids for a particular SAML username. + self.get_success( + store.register_user(user_id="@tester:test", password_hash=None) + ) + for i in range(1, 3): + self.get_success( + store.register_user(user_id="@tester%d:test" % i, password_hash=None) + ) + + # Now attempt to map to a username, this will fail since all potential usernames are taken. + saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"}) + self.get_success( + self.handler._handle_authn_response(request, saml_response, ""), + ) + sso_handler.render_error.assert_called_once_with( + request, + "mapping_error", + "Unable to generate a Matrix ID from the SSO response", + ) + auth_handler.complete_sso_login.assert_not_called() + + @override_config( + { + "saml2_config": { + "user_mapping_provider": { + "module": __name__ + ".TestRedirectMappingProvider" + }, + } + } + ) + def test_map_saml_response_redirect(self): + """Test a mapping provider that raises a RedirectException""" + + saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"}) + request = _mock_request() + e = self.get_failure( + self.handler._handle_authn_response(request, saml_response, ""), + RedirectException, + ) + self.assertEqual(e.value.location, b"https://custom-saml-redirect/") + + +def _mock_request(): + """Returns a mock which will stand in as a SynapseRequest""" + return Mock(spec=["getClientIP", "get_user_agent"]) diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index e178d7765b..e62586142e 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -16,7 +16,7 @@ from synapse.api.errors import Codes, ResourceLimitError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION from synapse.handlers.sync import SyncConfig -from synapse.types import UserID +from synapse.types import UserID, create_requester import tests.unittest import tests.utils @@ -38,6 +38,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): user_id1 = "@user1:test" user_id2 = "@user2:test" sync_config = self._generate_sync_config(user_id1) + requester = create_requester(user_id1) self.reactor.advance(100) # So we get not 0 time self.auth_blocking._limit_usage_by_mau = True @@ -45,21 +46,26 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # Check that the happy case does not throw errors self.get_success(self.store.upsert_monthly_active_user(user_id1)) - self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config)) + self.get_success( + self.sync_handler.wait_for_sync_for_user(requester, sync_config) + ) # Test that global lock works self.auth_blocking._hs_disabled = True e = self.get_failure( - self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError + self.sync_handler.wait_for_sync_for_user(requester, sync_config), + ResourceLimitError, ) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.auth_blocking._hs_disabled = False sync_config = self._generate_sync_config(user_id2) + requester = create_requester(user_id2) e = self.get_failure( - self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError + self.sync_handler.wait_for_sync_for_user(requester, sync_config), + ResourceLimitError, ) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 3fec09ea8a..96e5bdac4a 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -15,18 +15,20 @@ import json +from typing import Dict from mock import ANY, Mock, call from twisted.internet import defer +from twisted.web.resource import Resource from synapse.api.errors import AuthError +from synapse.federation.transport.server import TransportLayerServer from synapse.types import UserID, create_requester from tests import unittest from tests.test_utils import make_awaitable from tests.unittest import override_config -from tests.utils import register_federation_servlets # Some local users to test with U_APPLE = UserID.from_string("@apple:test") @@ -53,8 +55,6 @@ def _make_edu_transaction_json(edu_type, content): class TypingNotificationsTestCase(unittest.HomeserverTestCase): - servlets = [register_federation_servlets] - def make_homeserver(self, reactor, clock): # we mock out the keyring so as to skip the authentication check on the # federation API call. @@ -65,40 +65,23 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): mock_federation_client = Mock(spec=["put_json"]) mock_federation_client.put_json.return_value = defer.succeed((200, "OK")) - datastores = Mock() - datastores.main = Mock( - spec=[ - # Bits that Federation needs - "prep_send_transaction", - "delivered_txn", - "get_received_txn_response", - "set_received_txn_response", - "get_destination_last_successful_stream_ordering", - "get_destination_retry_timings", - "get_devices_by_remote", - "maybe_store_room_on_invite", - # Bits that user_directory needs - "get_user_directory_stream_pos", - "get_current_state_deltas", - "get_device_updates_by_remote", - "get_room_max_stream_ordering", - ] - ) - # the tests assume that we are starting at unix time 1000 reactor.pump((1000,)) hs = self.setup_test_homeserver( notifier=Mock(), - http_client=mock_federation_client, + federation_http_client=mock_federation_client, keyring=mock_keyring, replication_streams={}, ) - hs.datastores = datastores - return hs + def create_resource_dict(self) -> Dict[str, Resource]: + d = super().create_resource_dict() + d["/_matrix/federation"] = TransportLayerServer(self.hs) + return d + def prepare(self, reactor, clock, hs): mock_notifier = hs.get_notifier() self.on_new_event = mock_notifier.on_new_event @@ -114,16 +97,16 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): "retry_interval": 0, "failure_ts": None, } - self.datastore.get_destination_retry_timings.return_value = defer.succeed( - retry_timings_res + self.datastore.get_destination_retry_timings = Mock( + return_value=defer.succeed(retry_timings_res) ) - self.datastore.get_device_updates_by_remote.return_value = make_awaitable( - (0, []) + self.datastore.get_device_updates_by_remote = Mock( + return_value=make_awaitable((0, [])) ) - self.datastore.get_destination_last_successful_stream_ordering.return_value = make_awaitable( - None + self.datastore.get_destination_last_successful_stream_ordering = Mock( + return_value=make_awaitable(None) ) def get_received_txn_response(*args): @@ -145,17 +128,19 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room - def get_users_in_room(room_id): - return defer.succeed({str(u) for u in self.room_members}) + async def get_users_in_room(room_id): + return {str(u) for u in self.room_members} self.datastore.get_users_in_room = get_users_in_room - self.datastore.get_user_directory_stream_pos.side_effect = ( - # we deliberately return a non-None stream pos to avoid doing an initial_spam - lambda: make_awaitable(1) + 
self.datastore.get_user_directory_stream_pos = Mock( + side_effect=( + # we deliberately return a non-None stream pos to avoid doing an initial_spam + lambda: make_awaitable(1) + ) ) - self.datastore.get_current_state_deltas.return_value = (0, None) + self.datastore.get_current_state_deltas = Mock(return_value=(0, None)) self.datastore.get_to_device_stream_token = lambda: 0 self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable( @@ -212,7 +197,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) ) - put_json = self.hs.get_http_client().put_json + put_json = self.hs.get_federation_http_client().put_json put_json.assert_called_once_with( "farm", path="/_matrix/federation/v1/send/1000000", @@ -235,7 +220,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.assertEquals(self.event_source.get_current_key(), 0) - (request, channel) = self.make_request( + channel = self.make_request( "PUT", "/_matrix/federation/v1/send/1000000", _make_edu_transaction_json( @@ -248,7 +233,6 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ), federation_auth_origin=b"farm", ) - self.render(request) self.assertEqual(channel.code, 200) self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) @@ -291,7 +275,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) - put_json = self.hs.get_http_client().put_json + put_json = self.hs.get_federation_http_client().put_json put_json.assert_called_once_with( "farm", path="/_matrix/federation/v1/send/1000000", diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 87be94111f..9c886d671a 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -54,6 +54,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT ) ) + regular_user_id = "@regular:test" + self.get_success( + self.store.register_user(user_id=regular_user_id, password_hash=None) + ) self.get_success( self.handler.handle_local_profile_change(support_user_id, None) @@ -63,13 +67,47 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): display_name = "display_name" profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name) - regular_user_id = "@regular:test" self.get_success( self.handler.handle_local_profile_change(regular_user_id, profile_info) ) profile = self.get_success(self.store.get_user_in_directory(regular_user_id)) self.assertTrue(profile["display_name"] == display_name) + def test_handle_local_profile_change_with_deactivated_user(self): + # create user + r_user_id = "@regular:test" + self.get_success( + self.store.register_user(user_id=r_user_id, password_hash=None) + ) + + # update profile + display_name = "Regular User" + profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name) + self.get_success( + self.handler.handle_local_profile_change(r_user_id, profile_info) + ) + + # profile is in directory + profile = self.get_success(self.store.get_user_in_directory(r_user_id)) + self.assertTrue(profile["display_name"] == display_name) + + # deactivate user + self.get_success(self.store.set_user_deactivated_status(r_user_id, True)) + self.get_success(self.handler.handle_user_deactivated(r_user_id)) + + # profile is not in directory + profile = self.get_success(self.store.get_user_in_directory(r_user_id)) + self.assertTrue(profile is None) + + # update profile after deactivation + self.get_success( + self.handler.handle_local_profile_change(r_user_id, profile_info) + ) + + # profile is furthermore not in directory + profile = self.get_success(self.store.get_user_in_directory(r_user_id)) + self.assertTrue(profile is None) + def test_handle_user_deactivated_support_user(self): s_user_id = "@support:test" self.get_success( @@ -270,7 +308,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): spam_checker = self.hs.get_spam_checker() class AllowAll: - def check_username_for_spam(self, user_profile): + async def check_username_for_spam(self, user_profile): # Allow all users. return False @@ -283,7 +321,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Configure a spam checker that filters all users. class BlockAll: - def check_username_for_spam(self, user_profile): + async def check_username_for_spam(self, user_profile): # All users are spammy. 
return True @@ -534,18 +572,16 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): self.helper.join(room, user=u2) # Assert user directory is not empty - request, channel = self.make_request( + channel = self.make_request( "POST", b"user_directory/search", b'{"search_term":"user2"}' ) - self.render(request) self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["results"]) > 0) # Disable user directory and check search returns nothing self.config.user_directory_search_enabled = False - request, channel = self.make_request( + channel = self.make_request( "POST", b"user_directory/search", b'{"search_term":"user2"}' ) - self.render(request) self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["results"]) == 0) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 8b5ad4574f..4e51839d0f 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -17,6 +17,7 @@ import logging from mock import Mock import treq +from netaddr import IPSet from service_identity import VerificationError from zope.interface import implementer @@ -35,6 +36,7 @@ from synapse.crypto.context_factory import FederationPolicyForHTTPS from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.http.federation.srv_resolver import Server from synapse.http.federation.well_known_resolver import ( + WELL_KNOWN_MAX_SIZE, WellKnownResolver, _cache_period_from_headers, ) @@ -103,6 +105,7 @@ class MatrixFederationAgentTests(unittest.TestCase): reactor=self.reactor, tls_client_options_factory=self.tls_factory, user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided. + ip_blacklist=IPSet(), _srv_resolver=self.mock_resolver, _well_known_resolver=self.well_known_resolver, ) @@ -736,6 +739,7 @@ class MatrixFederationAgentTests(unittest.TestCase): reactor=self.reactor, tls_client_options_factory=tls_factory, user_agent=b"test-agent", # This is unused since _well_known_resolver is passed below. + ip_blacklist=IPSet(), _srv_resolver=self.mock_resolver, _well_known_resolver=WellKnownResolver( self.reactor, @@ -1104,6 +1108,32 @@ class MatrixFederationAgentTests(unittest.TestCase): r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, None) + def test_well_known_too_large(self): + """A well-known query that returns a result which is too large should be rejected.""" + self.reactor.lookups["testserv"] = "1.2.3.4" + + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) + + # there should be an attempt to connect on port 443 for the .well-known + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 443) + + self._handle_well_known_connection( + client_factory, + expected_sni=b"testserv", + response_headers={b"Cache-Control": b"max-age=1000"}, + content=b'{ "m.server": "' + (b"a" * WELL_KNOWN_MAX_SIZE) + b'" }', + ) + + # The result is sucessful, but disabled delegation. + r = self.successResultOf(fetch_d) + self.assertIsNone(r.delegated_server) + def test_srv_fallbacks(self): """Test that other SRV results are tried if the first one fails. """ diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py
index 62d36c2906..453391a5a5 100644
--- a/tests/http/test_additional_resource.py
+++ b/tests/http/test_additional_resource.py
@@ -17,6 +17,7 @@
 from synapse.http.additional_resource import AdditionalResource
 from synapse.http.server import respond_with_json
 
+from tests.server import FakeSite, make_request
 from tests.unittest import HomeserverTestCase
 
 
@@ -43,20 +44,18 @@ class AdditionalResourceTests(HomeserverTestCase):
 
     def test_async(self):
         handler = _AsyncTestCustomEndpoint({}, None).handle_request
-        self.resource = AdditionalResource(self.hs, handler)
+        resource = AdditionalResource(self.hs, handler)
 
-        request, channel = self.make_request("GET", "/")
-        self.render(request)
+        channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
 
-        self.assertEqual(request.code, 200)
+        self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
 
     def test_sync(self):
         handler = _SyncTestCustomEndpoint({}, None).handle_request
-        self.resource = AdditionalResource(self.hs, handler)
+        resource = AdditionalResource(self.hs, handler)
 
-        request, channel = self.make_request("GET", "/")
-        self.render(request)
+        channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
 
-        self.assertEqual(request.code, 200)
+        self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index 22abf76515..9a56e1c14a 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -15,12 +15,14 @@ import logging import treq +from netaddr import IPSet from twisted.internet import interfaces # noqa: F401 from twisted.internet.protocol import Factory from twisted.protocols.tls import TLSMemoryBIOFactory from twisted.web.http import HTTPChannel +from synapse.http.client import BlacklistingReactorWrapper from synapse.http.proxyagent import ProxyAgent from tests.http import TestServerTLSConnectionFactory, get_test_https_policy @@ -292,6 +294,134 @@ class MatrixFederationAgentTests(TestCase): body = self.successResultOf(treq.content(resp)) self.assertEqual(body, b"result") + def test_http_request_via_proxy_with_blacklist(self): + # The blacklist includes the configured proxy IP. + agent = ProxyAgent( + BlacklistingReactorWrapper( + self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"]) + ), + self.reactor, + http_proxy=b"proxy.com:8888", + ) + + self.reactor.lookups["proxy.com"] = "1.2.3.5" + d = agent.request(b"GET", b"http://test.com") + + # there should be a pending TCP connection + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.5") + self.assertEqual(port, 8888) + + # make a test server, and wire up the client + http_server = self._make_connection( + client_factory, _get_test_protocol_factory() + ) + + # the FakeTransport is async, so we need to pump the reactor + self.reactor.advance(0) + + # now there should be a pending request + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"http://test.com") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) + request.write(b"result") + request.finish() + + self.reactor.advance(0) + + resp = self.successResultOf(d) + body = self.successResultOf(treq.content(resp)) + self.assertEqual(body, b"result") + + def test_https_request_via_proxy_with_blacklist(self): + # The blacklist includes the configured proxy IP. 
+ agent = ProxyAgent( + BlacklistingReactorWrapper( + self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"]) + ), + self.reactor, + contextFactory=get_test_https_policy(), + https_proxy=b"proxy.com", + ) + + self.reactor.lookups["proxy.com"] = "1.2.3.5" + d = agent.request(b"GET", b"https://test.com/abc") + + # there should be a pending TCP connection + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.5") + self.assertEqual(port, 1080) + + # make a test HTTP server, and wire up the client + proxy_server = self._make_connection( + client_factory, _get_test_protocol_factory() + ) + + # fish the transports back out so that we can do the old switcheroo + s2c_transport = proxy_server.transport + client_protocol = s2c_transport.other + c2s_transport = client_protocol.transport + + # the FakeTransport is async, so we need to pump the reactor + self.reactor.advance(0) + + # now there should be a pending CONNECT request + self.assertEqual(len(proxy_server.requests), 1) + + request = proxy_server.requests[0] + self.assertEqual(request.method, b"CONNECT") + self.assertEqual(request.path, b"test.com:443") + + # tell the proxy server not to close the connection + proxy_server.persistent = True + + # this just stops the http Request trying to do a chunked response + # request.setHeader(b"Content-Length", b"0") + request.finish() + + # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel + ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory()) + ssl_protocol = ssl_factory.buildProtocol(None) + http_server = ssl_protocol.wrappedProtocol + + ssl_protocol.makeConnection( + FakeTransport(client_protocol, self.reactor, ssl_protocol) + ) + c2s_transport.other = ssl_protocol + + self.reactor.advance(0) + + server_name = ssl_protocol._tlsConnection.get_servername() + expected_sni = b"test.com" + self.assertEqual( + server_name, + expected_sni, + "Expected SNI %s but got %s" % (expected_sni, server_name), + ) + + # now there should be a pending request + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/abc") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) + request.write(b"result") + request.finish() + + self.reactor.advance(0) + + resp = self.successResultOf(d) + body = self.successResultOf(treq.content(resp)) + self.assertEqual(body, b"result") + def _wrap_server_factory_for_tls(factory, sanlist=None): """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory diff --git a/tests/logging/__init__.py b/tests/logging/__init__.py
index e69de29bb2..a58d51441c 100644
--- a/tests/logging/__init__.py
+++ b/tests/logging/__init__.py
@@ -0,0 +1,34 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+
+class LoggerCleanupMixin:
+    def get_logger(self, handler):
+        """
+        Attach a handler to a logger and add clean-ups to revert this afterwards.
+        """
+        # Create a logger and add the handler to it.
+        logger = logging.getLogger(__name__)
+        logger.addHandler(handler)
+
+        # Ensure the logger actually logs something.
+        logger.setLevel(logging.INFO)
+
+        # Ensure the logger gets cleaned-up appropriately.
+        self.addCleanup(logger.removeHandler, handler)
+        self.addCleanup(logger.setLevel, logging.NOTSET)
+
+        return logger
diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py
new file mode 100644
index 0000000000..4bc27a1d7d
--- /dev/null
+++ b/tests/logging/test_remote_handler.py
@@ -0,0 +1,169 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.test.proto_helpers import AccumulatingProtocol + +from synapse.logging import RemoteHandler + +from tests.logging import LoggerCleanupMixin +from tests.server import FakeTransport, get_clock +from tests.unittest import TestCase + + +def connect_logging_client(reactor, client_id): + # This is essentially tests.server.connect_client, but disabling autoflush on + # the client transport. This is necessary to avoid an infinite loop due to + # sending of data via the logging transport causing additional logs to be + # written. + factory = reactor.tcpClients.pop(client_id)[2] + client = factory.buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, reactor)) + client.makeConnection(FakeTransport(server, reactor, autoflush=False)) + + return client, server + + +class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): + def setUp(self): + self.reactor, _ = get_clock() + + def test_log_output(self): + """ + The remote handler delivers logs over TCP. + """ + handler = RemoteHandler("127.0.0.1", 9000, _reactor=self.reactor) + logger = self.get_logger(handler) + + logger.info("Hello there, %s!", "wally") + + # Trigger the connection + client, server = connect_logging_client(self.reactor, 0) + + # Trigger data being sent + client.transport.flush() + + # One log message, with a single trailing newline + logs = server.data.decode("utf8").splitlines() + self.assertEqual(len(logs), 1) + self.assertEqual(server.data.count(b"\n"), 1) + + # Ensure the data passed through properly. + self.assertEqual(logs[0], "Hello there, wally!") + + def test_log_backpressure_debug(self): + """ + When backpressure is hit, DEBUG logs will be shed. + """ + handler = RemoteHandler( + "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor + ) + logger = self.get_logger(handler) + + # Send some debug messages + for i in range(0, 3): + logger.debug("debug %s" % (i,)) + + # Send a bunch of useful messages + for i in range(0, 7): + logger.info("info %s" % (i,)) + + # The last debug message pushes it past the maximum buffer + logger.debug("too much debug") + + # Allow the reconnection + client, server = connect_logging_client(self.reactor, 0) + client.transport.flush() + + # Only the 7 infos made it through, the debugs were elided + logs = server.data.splitlines() + self.assertEqual(len(logs), 7) + self.assertNotIn(b"debug", server.data) + + def test_log_backpressure_info(self): + """ + When backpressure is hit, DEBUG and INFO logs will be shed. 
+ """ + handler = RemoteHandler( + "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor + ) + logger = self.get_logger(handler) + + # Send some debug messages + for i in range(0, 3): + logger.debug("debug %s" % (i,)) + + # Send a bunch of useful messages + for i in range(0, 10): + logger.warning("warn %s" % (i,)) + + # Send a bunch of info messages + for i in range(0, 3): + logger.info("info %s" % (i,)) + + # The last debug message pushes it past the maximum buffer + logger.debug("too much debug") + + # Allow the reconnection + client, server = connect_logging_client(self.reactor, 0) + client.transport.flush() + + # The 10 warnings made it through, the debugs and infos were elided + logs = server.data.splitlines() + self.assertEqual(len(logs), 10) + self.assertNotIn(b"debug", server.data) + self.assertNotIn(b"info", server.data) + + def test_log_backpressure_cut_middle(self): + """ + When backpressure is hit, and no more DEBUG and INFOs cannot be culled, + it will cut the middle messages out. + """ + handler = RemoteHandler( + "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor + ) + logger = self.get_logger(handler) + + # Send a bunch of useful messages + for i in range(0, 20): + logger.warning("warn %s" % (i,)) + + # Allow the reconnection + client, server = connect_logging_client(self.reactor, 0) + client.transport.flush() + + # The first five and last five warnings made it through, the debugs and + # infos were elided + logs = server.data.decode("utf8").splitlines() + self.assertEqual( + ["warn %s" % (i,) for i in range(5)] + + ["warn %s" % (i,) for i in range(15, 20)], + logs, + ) + + def test_cancel_connection(self): + """ + Gracefully handle the connection being cancelled. + """ + handler = RemoteHandler( + "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor + ) + logger = self.get_logger(handler) + + # Send a message. + logger.info("Hello there, %s!", "wally") + + # Do not accept the connection and shutdown. This causes the pending + # connection to be cancelled (and should not raise any exceptions). + handler.close() diff --git a/tests/logging/test_structured.py b/tests/logging/test_structured.py deleted file mode 100644
index d36f5f426c..0000000000
--- a/tests/logging/test_structured.py
+++ /dev/null
@@ -1,214 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import os -import os.path -import shutil -import sys -import textwrap - -from twisted.logger import Logger, eventAsText, eventsFromJSONLogFile - -from synapse.config.logger import setup_logging -from synapse.logging._structured import setup_structured_logging -from synapse.logging.context import LoggingContext - -from tests.unittest import DEBUG, HomeserverTestCase - - -class FakeBeginner: - def beginLoggingTo(self, observers, **kwargs): - self.observers = observers - - -class StructuredLoggingTestBase: - """ - Test base that registers a cleanup handler to reset the stdlib log handler - to 'unset'. - """ - - def prepare(self, reactor, clock, hs): - def _cleanup(): - logging.getLogger("synapse").setLevel(logging.NOTSET) - - self.addCleanup(_cleanup) - - -class StructuredLoggingTestCase(StructuredLoggingTestBase, HomeserverTestCase): - """ - Tests for Synapse's structured logging support. - """ - - def test_output_to_json_round_trip(self): - """ - Synapse logs can be outputted to JSON and then read back again. - """ - temp_dir = self.mktemp() - os.mkdir(temp_dir) - self.addCleanup(shutil.rmtree, temp_dir) - - json_log_file = os.path.abspath(os.path.join(temp_dir, "out.json")) - - log_config = { - "drains": {"jsonfile": {"type": "file_json", "location": json_log_file}} - } - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, self.hs.config, log_config, logBeginner=beginner - ) - - # Make a logger and send an event - logger = Logger( - namespace="tests.logging.test_structured", observer=beginner.observers[0] - ) - logger.info("Hello there, {name}!", name="wally") - - # Read the log file and check it has the event we sent - with open(json_log_file, "r") as f: - logged_events = list(eventsFromJSONLogFile(f)) - self.assertEqual(len(logged_events), 1) - - # The event pulled from the file should render fine - self.assertEqual( - eventAsText(logged_events[0], includeTimestamp=False), - "[tests.logging.test_structured#info] Hello there, wally!", - ) - - def test_output_to_text(self): - """ - Synapse logs can be outputted to text. 
- """ - temp_dir = self.mktemp() - os.mkdir(temp_dir) - self.addCleanup(shutil.rmtree, temp_dir) - - log_file = os.path.abspath(os.path.join(temp_dir, "out.log")) - - log_config = {"drains": {"file": {"type": "file", "location": log_file}}} - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, self.hs.config, log_config, logBeginner=beginner - ) - - # Make a logger and send an event - logger = Logger( - namespace="tests.logging.test_structured", observer=beginner.observers[0] - ) - logger.info("Hello there, {name}!", name="wally") - - # Read the log file and check it has the event we sent - with open(log_file, "r") as f: - logged_events = f.read().strip().split("\n") - self.assertEqual(len(logged_events), 1) - - # The event pulled from the file should render fine - self.assertTrue( - logged_events[0].endswith( - " - tests.logging.test_structured - INFO - None - Hello there, wally!" - ) - ) - - def test_collects_logcontext(self): - """ - Test that log outputs have the attached logging context. - """ - log_config = {"drains": {}} - - # Begin the logger with our config - beginner = FakeBeginner() - publisher = setup_structured_logging( - self.hs, self.hs.config, log_config, logBeginner=beginner - ) - - logs = [] - - publisher.addObserver(logs.append) - - # Make a logger and send an event - logger = Logger( - namespace="tests.logging.test_structured", observer=beginner.observers[0] - ) - - with LoggingContext("testcontext", request="somereq"): - logger.info("Hello there, {name}!", name="steve") - - self.assertEqual(len(logs), 1) - self.assertEqual(logs[0]["request"], "somereq") - - -class StructuredLoggingConfigurationFileTestCase( - StructuredLoggingTestBase, HomeserverTestCase -): - def make_homeserver(self, reactor, clock): - - tempdir = self.mktemp() - os.mkdir(tempdir) - log_config_file = os.path.abspath(os.path.join(tempdir, "log.config.yaml")) - self.homeserver_log = os.path.abspath(os.path.join(tempdir, "homeserver.log")) - - config = self.default_config() - config["log_config"] = log_config_file - - with open(log_config_file, "w") as f: - f.write( - textwrap.dedent( - """\ - structured: true - - drains: - file: - type: file_json - location: %s - """ - % (self.homeserver_log,) - ) - ) - - self.addCleanup(self._sys_cleanup) - - return self.setup_test_homeserver(config=config) - - def _sys_cleanup(self): - sys.stdout = sys.__stdout__ - sys.stderr = sys.__stderr__ - - # Do not remove! We need the logging system to be set other than WARNING. - @DEBUG - def test_log_output(self): - """ - When a structured logging config is given, Synapse will use it. - """ - beginner = FakeBeginner() - publisher = setup_logging(self.hs, self.hs.config, logBeginner=beginner) - - # Make a logger and send an event - logger = Logger(namespace="tests.logging.test_structured", observer=publisher) - - with LoggingContext("testcontext", request="somereq"): - logger.info("Hello there, {name}!", name="steve") - - with open(self.homeserver_log, "r") as f: - logged_events = [ - eventAsText(x, includeTimestamp=False) for x in eventsFromJSONLogFile(f) - ] - - logs = "\n".join(logged_events) - self.assertTrue("***** STARTING SERVER *****" in logs) - self.assertTrue("Hello there, steve!" in logs) diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index 4cf81f7128..48a74e2eee 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py
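The hunks below drop the old TCP-drain and backpressure tests along with the structured-logging machinery removed above, and instead drive Synapse's new stdlib-logging formatters (TerseJsonFormatter, JsonFormatter) through an ordinary logging.StreamHandler pointed at a StringIO. A minimal standalone version of that test pattern, using a hypothetical StubJsonFormatter rather than the real formatter, might look like this:

    import json
    import logging
    from io import StringIO

    class StubJsonFormatter(logging.Formatter):
        # Hypothetical stand-in for TerseJsonFormatter: render a flat JSON object.
        def format(self, record):
            return json.dumps(
                {
                    "log": record.getMessage(),
                    "level": record.levelname,
                    "namespace": record.name,
                }
            )

    output = StringIO()
    handler = logging.StreamHandler(output)
    handler.setFormatter(StubJsonFormatter())

    logger = logging.getLogger("tests.logging.example")
    logger.setLevel(logging.INFO)
    logger.addHandler(handler)
    logger.info("Hello there, %s!", "wally")

    log = json.loads(output.getvalue().splitlines()[0])
    assert log["log"] == "Hello there, wally!"
    assert log["namespace"] == "tests.logging.example"

Because everything now goes through the stdlib logging machinery, the tests no longer need a fake Twisted log beginner or a fake TCP server.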
@@ -14,221 +14,124 @@ # limitations under the License. import json -from collections import Counter +import logging +from io import StringIO -from twisted.logger import Logger +from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter +from synapse.logging.context import LoggingContext, LoggingContextFilter -from synapse.logging._structured import setup_structured_logging +from tests.logging import LoggerCleanupMixin +from tests.unittest import TestCase -from tests.server import connect_client -from tests.unittest import HomeserverTestCase -from .test_structured import FakeBeginner, StructuredLoggingTestBase +class TerseJsonTestCase(LoggerCleanupMixin, TestCase): + def setUp(self): + self.output = StringIO() + def get_log_line(self): + # One log message, with a single trailing newline. + data = self.output.getvalue() + logs = data.splitlines() + self.assertEqual(len(logs), 1) + self.assertEqual(data.count("\n"), 1) + return json.loads(logs[0]) -class TerseJSONTCPTestCase(StructuredLoggingTestBase, HomeserverTestCase): - def test_log_output(self): + def test_terse_json_output(self): """ - The Terse JSON outputter delivers simplified structured logs over TCP. + The Terse JSON formatter converts log messages to JSON. """ - log_config = { - "drains": { - "tersejson": { - "type": "network_json_terse", - "host": "127.0.0.1", - "port": 8000, - } - } - } - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, self.hs.config, log_config, logBeginner=beginner - ) - - logger = Logger( - namespace="tests.logging.test_terse_json", observer=beginner.observers[0] - ) - logger.info("Hello there, {name}!", name="wally") - - # Trigger the connection - self.pump() + handler = logging.StreamHandler(self.output) + handler.setFormatter(TerseJsonFormatter()) + logger = self.get_logger(handler) - _, server = connect_client(self.reactor, 0) + logger.info("Hello there, %s!", "wally") - # Trigger data being sent - self.pump() - - # One log message, with a single trailing newline - logs = server.data.decode("utf8").splitlines() - self.assertEqual(len(logs), 1) - self.assertEqual(server.data.count(b"\n"), 1) - - log = json.loads(logs[0]) + log = self.get_log_line() # The terse logger should give us these keys. expected_log_keys = [ "log", "time", "level", - "log_namespace", - "request", - "scope", - "server_name", - "name", + "namespace", ] - self.assertEqual(set(log.keys()), set(expected_log_keys)) - - # It contains the data we expect. - self.assertEqual(log["name"], "wally") + self.assertCountEqual(log.keys(), expected_log_keys) + self.assertEqual(log["log"], "Hello there, wally!") - def test_log_backpressure_debug(self): + def test_extra_data(self): """ - When backpressure is hit, DEBUG logs will be shed. + Additional information can be included in the structured logging. 
""" - log_config = { - "loggers": {"synapse": {"level": "DEBUG"}}, - "drains": { - "tersejson": { - "type": "network_json_terse", - "host": "127.0.0.1", - "port": 8000, - "maximum_buffer": 10, - } - }, - } - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, - self.hs.config, - log_config, - logBeginner=beginner, - redirect_stdlib_logging=False, - ) + handler = logging.StreamHandler(self.output) + handler.setFormatter(TerseJsonFormatter()) + logger = self.get_logger(handler) - logger = Logger( - namespace="synapse.logging.test_terse_json", observer=beginner.observers[0] + logger.info( + "Hello there, %s!", "wally", extra={"foo": "bar", "int": 3, "bool": True} ) - # Send some debug messages - for i in range(0, 3): - logger.debug("debug %s" % (i,)) - - # Send a bunch of useful messages - for i in range(0, 7): - logger.info("test message %s" % (i,)) - - # The last debug message pushes it past the maximum buffer - logger.debug("too much debug") + log = self.get_log_line() - # Allow the reconnection - _, server = connect_client(self.reactor, 0) - self.pump() + # The terse logger should give us these keys. + expected_log_keys = [ + "log", + "time", + "level", + "namespace", + # The additional keys given via extra. + "foo", + "int", + "bool", + ] + self.assertCountEqual(log.keys(), expected_log_keys) - # Only the 7 infos made it through, the debugs were elided - logs = server.data.splitlines() - self.assertEqual(len(logs), 7) + # Check the values of the extra fields. + self.assertEqual(log["foo"], "bar") + self.assertEqual(log["int"], 3) + self.assertIs(log["bool"], True) - def test_log_backpressure_info(self): + def test_json_output(self): """ - When backpressure is hit, DEBUG and INFO logs will be shed. + The Terse JSON formatter converts log messages to JSON. """ - log_config = { - "loggers": {"synapse": {"level": "DEBUG"}}, - "drains": { - "tersejson": { - "type": "network_json_terse", - "host": "127.0.0.1", - "port": 8000, - "maximum_buffer": 10, - } - }, - } - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, - self.hs.config, - log_config, - logBeginner=beginner, - redirect_stdlib_logging=False, - ) + handler = logging.StreamHandler(self.output) + handler.setFormatter(JsonFormatter()) + logger = self.get_logger(handler) - logger = Logger( - namespace="synapse.logging.test_terse_json", observer=beginner.observers[0] - ) - - # Send some debug messages - for i in range(0, 3): - logger.debug("debug %s" % (i,)) - - # Send a bunch of useful messages - for i in range(0, 10): - logger.warn("test warn %s" % (i,)) - - # Send a bunch of info messages - for i in range(0, 3): - logger.info("test message %s" % (i,)) - - # The last debug message pushes it past the maximum buffer - logger.debug("too much debug") + logger.info("Hello there, %s!", "wally") - # Allow the reconnection - client, server = connect_client(self.reactor, 0) - self.pump() + log = self.get_log_line() - # The 10 warnings made it through, the debugs and infos were elided - logs = list(map(json.loads, server.data.decode("utf8").splitlines())) - self.assertEqual(len(logs), 10) - - self.assertEqual(Counter([x["level"] for x in logs]), {"WARN": 10}) + # The terse logger should give us these keys. 
+ expected_log_keys = [ + "log", + "level", + "namespace", + ] + self.assertCountEqual(log.keys(), expected_log_keys) + self.assertEqual(log["log"], "Hello there, wally!") - def test_log_backpressure_cut_middle(self): + def test_with_context(self): """ - When backpressure is hit, and no more DEBUG and INFOs cannot be culled, - it will cut the middle messages out. + The logging context should be added to the JSON response. """ - log_config = { - "loggers": {"synapse": {"level": "DEBUG"}}, - "drains": { - "tersejson": { - "type": "network_json_terse", - "host": "127.0.0.1", - "port": 8000, - "maximum_buffer": 10, - } - }, - } - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, - self.hs.config, - log_config, - logBeginner=beginner, - redirect_stdlib_logging=False, - ) + handler = logging.StreamHandler(self.output) + handler.setFormatter(JsonFormatter()) + handler.addFilter(LoggingContextFilter()) + logger = self.get_logger(handler) - logger = Logger( - namespace="synapse.logging.test_terse_json", observer=beginner.observers[0] - ) + with LoggingContext(request="test"): + logger.info("Hello there, %s!", "wally") - # Send a bunch of useful messages - for i in range(0, 20): - logger.warn("test warn", num=i) + log = self.get_log_line() - # Allow the reconnection - client, server = connect_client(self.reactor, 0) - self.pump() - - # The first five and last five warnings made it through, the debugs and - # infos were elided - logs = list(map(json.loads, server.data.decode("utf8").splitlines())) - self.assertEqual(len(logs), 10) - self.assertEqual(Counter([x["level"] for x in logs]), {"WARN": 10}) - self.assertEqual([0, 1, 2, 3, 4, 15, 16, 17, 18, 19], [x["num"] for x in logs]) + # The terse logger should give us these keys. + expected_log_keys = [ + "log", + "level", + "namespace", + "request", + ] + self.assertCountEqual(log.keys(), expected_log_keys) + self.assertEqual(log["log"], "Hello there, wally!") + self.assertEqual(log["request"], "test") diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 04de0b9dbe..27206ca3db 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py
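The new test_sending_events_into_room below spies on create_and_send_nonmember_event by replacing it with a Mock whose side_effect is the original bound method: the call still goes through, but the arguments can be asserted on afterwards. A standalone sketch of that spy pattern, using stdlib unittest.mock and a made-up EventCreator class:

    from unittest.mock import Mock

    class EventCreator:
        def create_and_send(self, sender, body):
            return {"sender": sender, "body": body}

    creator = EventCreator()

    # The right-hand side is evaluated first, so side_effect captures the real
    # bound method before the attribute is overwritten with the Mock.
    creator.create_and_send = Mock(spec=[], side_effect=creator.create_and_send)

    event = creator.create_and_send("@user:test", "hello")
    assert event == {"sender": "@user:test", "body": "hello"}
    creator.create_and_send.assert_called_once_with("@user:test", "hello")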
@@ -12,16 +12,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock -from synapse.module_api import ModuleApi +from synapse.events import EventBase +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.types import create_requester from tests.unittest import HomeserverTestCase class ModuleApiTestCase(HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + def prepare(self, reactor, clock, homeserver): self.store = homeserver.get_datastore() - self.module_api = ModuleApi(homeserver, homeserver.get_auth_handler()) + self.module_api = homeserver.get_module_api() + self.event_creation_handler = homeserver.get_event_creation_handler() def test_can_register_user(self): """Tests that an external module can register a user""" @@ -52,3 +63,142 @@ class ModuleApiTestCase(HomeserverTestCase): # Check that the displayname was assigned displayname = self.get_success(self.store.get_profile_displayname("bob")) self.assertEqual(displayname, "Bobberino") + + def test_sending_events_into_room(self): + """Tests that a module can send events into a room""" + # Mock out create_and_send_nonmember_event to check whether events are being sent + self.event_creation_handler.create_and_send_nonmember_event = Mock( + spec=[], + side_effect=self.event_creation_handler.create_and_send_nonmember_event, + ) + + # Create a user and room to play with + user_id = self.register_user("summer", "monkey") + tok = self.login("summer", "monkey") + room_id = self.helper.create_room_as(user_id, tok=tok) + + # Create and send a non-state event + content = {"body": "I am a puppet", "msgtype": "m.text"} + event_dict = { + "room_id": room_id, + "type": "m.room.message", + "content": content, + "sender": user_id, + } + event = self.get_success( + self.module_api.create_and_send_event_into_room(event_dict) + ) # type: EventBase + self.assertEqual(event.sender, user_id) + self.assertEqual(event.type, "m.room.message") + self.assertEqual(event.room_id, room_id) + self.assertFalse(hasattr(event, "state_key")) + self.assertDictEqual(event.content, content) + + expected_requester = create_requester( + user_id, authenticated_entity=self.hs.hostname + ) + + # Check that the event was sent + self.event_creation_handler.create_and_send_nonmember_event.assert_called_with( + expected_requester, event_dict, ratelimit=False, ignore_shadow_ban=True, + ) + + # Create and send a state event + content = { + "events_default": 0, + "users": {user_id: 100}, + "state_default": 50, + "users_default": 0, + "events": {"test.event.type": 25}, + } + event_dict = { + "room_id": room_id, + "type": "m.room.power_levels", + "content": content, + "sender": user_id, + "state_key": "", + } + event = self.get_success( + self.module_api.create_and_send_event_into_room(event_dict) + ) # type: EventBase + self.assertEqual(event.sender, user_id) + self.assertEqual(event.type, "m.room.power_levels") + self.assertEqual(event.room_id, room_id) + self.assertEqual(event.state_key, "") + self.assertDictEqual(event.content, content) + + # Check that the event was sent + self.event_creation_handler.create_and_send_nonmember_event.assert_called_with( + expected_requester, + { + "type": "m.room.power_levels", + "content": content, + "room_id": room_id, + "sender": user_id, + "state_key": "", + }, + ratelimit=False, + 
ignore_shadow_ban=True, + ) + + # Check that we can't send membership events + content = { + "membership": "leave", + } + event_dict = { + "room_id": room_id, + "type": "m.room.member", + "content": content, + "sender": user_id, + "state_key": user_id, + } + self.get_failure( + self.module_api.create_and_send_event_into_room(event_dict), Exception + ) + + def test_public_rooms(self): + """Tests that a room can be added and removed from the public rooms list, + as well as have its public rooms directory state queried. + """ + # Create a user and room to play with + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + room_id = self.helper.create_room_as(user_id, tok=tok) + + # The room should not currently be in the public rooms directory + is_in_public_rooms = self.get_success( + self.module_api.public_room_list_manager.room_is_in_public_room_list( + room_id + ) + ) + self.assertFalse(is_in_public_rooms) + + # Let's try adding it to the public rooms directory + self.get_success( + self.module_api.public_room_list_manager.add_room_to_public_room_list( + room_id + ) + ) + + # And checking whether it's in there... + is_in_public_rooms = self.get_success( + self.module_api.public_room_list_manager.room_is_in_public_room_list( + room_id + ) + ) + self.assertTrue(is_in_public_rooms) + + # Let's remove it again + self.get_success( + self.module_api.public_room_list_manager.remove_room_from_public_room_list( + room_id + ) + ) + + # Should be gone + is_in_public_rooms = self.get_success( + self.module_api.public_room_list_manager.room_is_in_public_room_list( + room_id + ) + ) + self.assertFalse(is_in_public_rooms) diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 3224568640..961bf09de9 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py
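The user_tuple["token_id"] to user_tuple.token_id and pushers[0]["last_stream_ordering"] to pushers[0].last_stream_ordering changes below track a storage-layer change: these lookups now return attrs objects rather than plain dicts, so fields are read as attributes. Roughly, with illustrative field names (the real classes live in synapse.storage and carry more fields):

    import attr

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class TokenLookupResult:
        user_id: str
        token_id: int

    row = TokenLookupResult(user_id="@user:test", token_id=42)
    assert row.token_id == 42   # attribute access; the old row["token_id"] now fails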
@@ -100,7 +100,7 @@ class EmailPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(self.access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.pusher = self.get_success( self.hs.get_pusherpool().add_pusher( @@ -131,6 +131,35 @@ class EmailPusherTests(HomeserverTestCase): # We should get emailed about that message self._check_for_mail() + def test_invite_sends_email(self): + # Create a room and invite the user to it + room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token) + self.helper.invite( + room=room, + src=self.others[0].id, + tok=self.others[0].token, + targ=self.user_id, + ) + + # We should get emailed about the invite + self._check_for_mail() + + def test_invite_to_empty_room_sends_email(self): + # Create a room and invite the user to it + room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token) + self.helper.invite( + room=room, + src=self.others[0].id, + tok=self.others[0].token, + targ=self.user_id, + ) + + # Then have the original user leave + self.helper.leave(room, self.others[0].id, tok=self.others[0].token) + + # We should get emailed about the invite + self._check_for_mail() + def test_multiple_members_email(self): # We want to test multiple notifications, so we pause processing of push # while we send messages. @@ -158,8 +187,21 @@ class EmailPusherTests(HomeserverTestCase): # We should get emailed about those messages self._check_for_mail() + def test_encrypted_message(self): + room = self.helper.create_room_as(self.user_id, tok=self.access_token) + self.helper.invite( + room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id + ) + self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token) + + # The other user sends some messages + self.helper.send_event(room, "m.room.encrypted", {}, tok=self.others[0].token) + + # We should get emailed about that message + self._check_for_mail() + def _check_for_mail(self): - "Check that the user receives an email notification" + """Check that the user receives an email notification""" # Get the stream ordering before it gets sent pushers = self.get_success( @@ -167,7 +209,7 @@ class EmailPusherTests(HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - last_stream_ordering = pushers[0]["last_stream_ordering"] + last_stream_ordering = pushers[0].last_stream_ordering # Advance time a bit, so the pusher will register something has happened self.pump(10) @@ -178,7 +220,7 @@ class EmailPusherTests(HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"]) + self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering) # One email was attempted to be sent self.assertEqual(len(self.email_attempts), 1) @@ -196,4 +238,4 @@ class EmailPusherTests(HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) + self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index b567868b02..60f0820cff 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py
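The new test_invalid_configuration below expects PusherConfigException for pusher data without a URL, with a non-string URL, with a bare domain, or with a missing or wrong path, and the remaining tests switch the pusher URL to the full http://example.com/_matrix/push/v1/notify endpoint. A rough sketch of the kind of validation being exercised, not Synapse's actual implementation (PusherConfigException is redefined locally here to keep the sketch self-contained):

    from urllib.parse import urlparse

    class PusherConfigException(Exception):
        pass

    def check_push_data(data):
        # Reject anything that is not a dict carrying an absolute URL for the
        # push gateway's notify endpoint.
        if not isinstance(data, dict) or not isinstance(data.get("url"), str):
            raise PusherConfigException("'url' is required in data")
        parsed = urlparse(data["url"])
        if parsed.scheme not in ("http", "https") or not parsed.netloc:
            raise PusherConfigException("'url' must be an absolute http(s) URL")
        if parsed.path != "/_matrix/push/v1/notify":
            raise PusherConfigException("'url' must point at /_matrix/push/v1/notify")

    check_push_data({"url": "http://example.com/_matrix/push/v1/notify"})  # accepted
    for bad in (None, {}, {"url": 1}, {"url": "example.com"},
                {"url": "http://example.com"}, {"url": "http://example.com/foo"}):
        try:
            check_push_data(bad)
        except PusherConfigException:
            pass
        else:
            raise AssertionError("should have been rejected: %r" % (bad,))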
@@ -12,16 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from mock import Mock from twisted.internet.defer import Deferred import synapse.rest.admin from synapse.logging.context import make_deferred_yieldable +from synapse.push import PusherConfigException from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import receipts -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, override_config class HTTPPusherTests(HomeserverTestCase): @@ -29,10 +30,16 @@ class HTTPPusherTests(HomeserverTestCase): synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, + receipts.register_servlets, ] user_id = True hijack_auth = False + def default_config(self): + config = super().default_config() + config["start_pushers"] = True + return config + def make_homeserver(self, reactor, clock): self.push_attempts = [] @@ -45,13 +52,49 @@ class HTTPPusherTests(HomeserverTestCase): m.post_json_get_json = post_json_get_json - config = self.default_config() - config["start_pushers"] = True - - hs = self.setup_test_homeserver(config=config, proxied_http_client=m) + hs = self.setup_test_homeserver(proxied_blacklisted_http_client=m) return hs + def test_invalid_configuration(self): + """Invalid push configurations should be rejected.""" + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the pusher + user_tuple = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) + ) + token_id = user_tuple.token_id + + def test_data(data): + self.get_failure( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data=data, + ), + PusherConfigException, + ) + + # Data must be provided with a URL. + test_data(None) + test_data({}) + test_data({"url": 1}) + # A bare domain name isn't accepted. + test_data({"url": "example.com"}) + # A URL without a path isn't accepted. + test_data({"url": "http://example.com"}) + # A url with an incorrect path isn't accepted. 
+ test_data({"url": "http://example.com/foo"}) + def test_sends_http(self): """ The HTTP pusher will send pushes for each message to a HTTP endpoint @@ -69,7 +112,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( @@ -81,7 +124,7 @@ class HTTPPusherTests(HomeserverTestCase): device_display_name="pushy push", pushkey="a@example.com", lang=None, - data={"url": "example.com"}, + data={"url": "http://example.com/_matrix/push/v1/notify"}, ) ) @@ -101,7 +144,7 @@ class HTTPPusherTests(HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - last_stream_ordering = pushers[0]["last_stream_ordering"] + last_stream_ordering = pushers[0].last_stream_ordering # Advance time a bit, so the pusher will register something has happened self.pump() @@ -112,11 +155,13 @@ class HTTPPusherTests(HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"]) + self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering) # One push was attempted to be sent -- it'll be the first message self.assertEqual(len(self.push_attempts), 1) - self.assertEqual(self.push_attempts[0][1], "example.com") + self.assertEqual( + self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify" + ) self.assertEqual( self.push_attempts[0][2]["notification"]["content"]["body"], "Hi!" ) @@ -131,12 +176,14 @@ class HTTPPusherTests(HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) - last_stream_ordering = pushers[0]["last_stream_ordering"] + self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) + last_stream_ordering = pushers[0].last_stream_ordering # Now it'll try and send the second push message, which will be the second one self.assertEqual(len(self.push_attempts), 2) - self.assertEqual(self.push_attempts[1][1], "example.com") + self.assertEqual( + self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify" + ) self.assertEqual( self.push_attempts[1][2]["notification"]["content"]["body"], "There!" 
) @@ -151,7 +198,7 @@ class HTTPPusherTests(HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) + self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) def test_sends_high_priority_for_encrypted(self): """ @@ -181,7 +228,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( @@ -193,7 +240,7 @@ class HTTPPusherTests(HomeserverTestCase): device_display_name="pushy push", pushkey="a@example.com", lang=None, - data={"url": "example.com"}, + data={"url": "http://example.com/_matrix/push/v1/notify"}, ) ) @@ -229,7 +276,9 @@ class HTTPPusherTests(HomeserverTestCase): # Check our push made it with high priority self.assertEqual(len(self.push_attempts), 1) - self.assertEqual(self.push_attempts[0][1], "example.com") + self.assertEqual( + self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify" + ) self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high") # Add yet another person — we want to make this room not a 1:1 @@ -267,7 +316,9 @@ class HTTPPusherTests(HomeserverTestCase): # Advance time a bit, so the pusher will register something has happened self.pump() self.assertEqual(len(self.push_attempts), 2) - self.assertEqual(self.push_attempts[1][1], "example.com") + self.assertEqual( + self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify" + ) self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high") def test_sends_high_priority_for_one_to_one_only(self): @@ -297,7 +348,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( @@ -309,7 +360,7 @@ class HTTPPusherTests(HomeserverTestCase): device_display_name="pushy push", pushkey="a@example.com", lang=None, - data={"url": "example.com"}, + data={"url": "http://example.com/_matrix/push/v1/notify"}, ) ) @@ -325,7 +376,9 @@ class HTTPPusherTests(HomeserverTestCase): # Check our push made it with high priority — this is a one-to-one room self.assertEqual(len(self.push_attempts), 1) - self.assertEqual(self.push_attempts[0][1], "example.com") + self.assertEqual( + self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify" + ) self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high") # Yet another user joins @@ -344,7 +397,9 @@ class HTTPPusherTests(HomeserverTestCase): # Advance time a bit, so the pusher will register something has happened self.pump() self.assertEqual(len(self.push_attempts), 2) - self.assertEqual(self.push_attempts[1][1], "example.com") + self.assertEqual( + self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify" + ) # check that this is low-priority self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") @@ -379,7 +434,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( @@ -391,7 +446,7 @@ class HTTPPusherTests(HomeserverTestCase): device_display_name="pushy push", 
pushkey="a@example.com", lang=None, - data={"url": "example.com"}, + data={"url": "http://example.com/_matrix/push/v1/notify"}, ) ) @@ -407,7 +462,9 @@ class HTTPPusherTests(HomeserverTestCase): # Check our push made it with high priority self.assertEqual(len(self.push_attempts), 1) - self.assertEqual(self.push_attempts[0][1], "example.com") + self.assertEqual( + self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify" + ) self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high") # Send another event, this time with no mention @@ -416,7 +473,9 @@ class HTTPPusherTests(HomeserverTestCase): # Advance time a bit, so the pusher will register something has happened self.pump() self.assertEqual(len(self.push_attempts), 2) - self.assertEqual(self.push_attempts[1][1], "example.com") + self.assertEqual( + self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify" + ) # check that this is low-priority self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") @@ -452,7 +511,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( @@ -464,7 +523,7 @@ class HTTPPusherTests(HomeserverTestCase): device_display_name="pushy push", pushkey="a@example.com", lang=None, - data={"url": "example.com"}, + data={"url": "http://example.com/_matrix/push/v1/notify"}, ) ) @@ -484,7 +543,9 @@ class HTTPPusherTests(HomeserverTestCase): # Check our push made it with high priority self.assertEqual(len(self.push_attempts), 1) - self.assertEqual(self.push_attempts[0][1], "example.com") + self.assertEqual( + self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify" + ) self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high") # Send another event, this time as someone without the power of @room @@ -495,7 +556,169 @@ class HTTPPusherTests(HomeserverTestCase): # Advance time a bit, so the pusher will register something has happened self.pump() self.assertEqual(len(self.push_attempts), 2) - self.assertEqual(self.push_attempts[1][1], "example.com") + self.assertEqual( + self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify" + ) # check that this is low-priority self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") + + def test_push_unread_count_group_by_room(self): + """ + The HTTP pusher will group unread count by number of unread rooms. + """ + # Carry out common push count tests and setup + self._test_push_unread_count() + + # Carry out our option-value specific test + # + # This push should still only contain an unread count of 1 (for 1 unread room) + self.assertEqual( + self.push_attempts[5][2]["notification"]["counts"]["unread"], 1 + ) + + @override_config({"push": {"group_unread_count_by_room": False}}) + def test_push_unread_count_message_count(self): + """ + The HTTP pusher will send the total unread message count. 
+ """ + # Carry out common push count tests and setup + self._test_push_unread_count() + + # Carry out our option-value specific test + # + # We're counting every unread message, so there should now be 4 since the + # last read receipt + self.assertEqual( + self.push_attempts[5][2]["notification"]["counts"]["unread"], 4 + ) + + def _test_push_unread_count(self): + """ + Tests that the correct unread count appears in sent push notifications + + Note that: + * Sending messages will cause push notifications to go out to relevant users + * Sending a read receipt will cause a "badge update" notification to go out to + the user that sent the receipt + """ + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the user who sends the message + other_user_id = self.register_user("other_user", "pass") + other_access_token = self.login("other_user", "pass") + + # Create a room (as other_user) + room_id = self.helper.create_room_as(other_user_id, tok=other_access_token) + + # The user to get notified joins + self.helper.join(room=room_id, user=user_id, tok=access_token) + + # Register the pusher + user_tuple = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) + ) + token_id = user_tuple.token_id + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "http://example.com/_matrix/push/v1/notify"}, + ) + ) + + # Send a message + response = self.helper.send( + room_id, body="Hello there!", tok=other_access_token + ) + # To get an unread count, the user who is getting notified has to have a read + # position in the room. 
We'll set the read position to this event in a moment + first_message_event_id = response["event_id"] + + # Advance time a bit (so the pusher will register something has happened) and + # make the push succeed + self.push_attempts[0][0].callback({}) + self.pump() + + # Check our push made it + self.assertEqual(len(self.push_attempts), 1) + self.assertEqual( + self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify" + ) + + # Check that the unread count for the room is 0 + # + # The unread count is zero as the user has no read receipt in the room yet + self.assertEqual( + self.push_attempts[0][2]["notification"]["counts"]["unread"], 0 + ) + + # Now set the user's read receipt position to the first event + # + # This will actually trigger a new notification to be sent out so that + # even if the user does not receive another message, their unread + # count goes down + channel = self.make_request( + "POST", + "/rooms/%s/receipt/m.read/%s" % (room_id, first_message_event_id), + {}, + access_token=access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Advance time and make the push succeed + self.push_attempts[1][0].callback({}) + self.pump() + + # Unread count is still zero as we've read the only message in the room + self.assertEqual(len(self.push_attempts), 2) + self.assertEqual( + self.push_attempts[1][2]["notification"]["counts"]["unread"], 0 + ) + + # Send another message + self.helper.send( + room_id, body="How's the weather today?", tok=other_access_token + ) + + # Advance time and make the push succeed + self.push_attempts[2][0].callback({}) + self.pump() + + # This push should contain an unread count of 1 as there's now been one + # message since our last read receipt + self.assertEqual(len(self.push_attempts), 3) + self.assertEqual( + self.push_attempts[2][2]["notification"]["counts"]["unread"], 1 + ) + + # Since we're grouping by room, sending more messages shouldn't increase the + # unread count, as they're all being sent in the same room + self.helper.send(room_id, body="Hello?", tok=other_access_token) + + # Advance time and make the push succeed + self.pump() + self.push_attempts[3][0].callback({}) + + self.helper.send(room_id, body="Hello??", tok=other_access_token) + + # Advance time and make the push succeed + self.pump() + self.push_attempts[4][0].callback({}) + + self.helper.send(room_id, body="HELLO???", tok=other_access_token) + + # Advance time and make the push succeed + self.pump() + self.push_attempts[5][0].callback({}) + + self.assertEqual(len(self.push_attempts), 6) diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index ae60874ec3..3379189785 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py
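The reworked test base below adds an optional fake Redis pub/sub server: when Redis support is enabled in the config, worker connections to localhost:6379 are wired to FakeRedisPubSubServer, which speaks just enough of the Redis wire protocol (RESP) for SUBSCRIBE and PUBLISH. For reference, the framing produced by the encode helper added further down in this file is essentially standard RESP; a standalone equivalent, where "example_channel" is only a placeholder and not the channel name Synapse actually uses:

    def encode_resp(obj):
        # Mirrors FakeRedisPubSubProtocol.encode: strings/bytes, ints and lists only.
        if isinstance(obj, bytes):
            obj = obj.decode("utf-8")
        if isinstance(obj, str):
            return "${}\r\n{}\r\n".format(len(obj), obj)
        if isinstance(obj, int):
            return ":{}\r\n".format(obj)
        if isinstance(obj, (list, tuple)):
            return "*{}\r\n{}".format(len(obj), "".join(encode_resp(a) for a in obj))
        raise TypeError("cannot encode %r" % (type(obj),))

    # The acknowledgement the fake server sends back for a SUBSCRIBE:
    assert encode_resp(["subscribe", "example_channel", 1]) == (
        "*3\r\n$9\r\nsubscribe\r\n$15\r\nexample_channel\r\n:1\r\n"
    )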
@@ -12,23 +12,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import attr from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime +from twisted.internet.protocol import Protocol from twisted.internet.task import LoopingCall from twisted.web.http import HTTPChannel +from twisted.web.resource import Resource from synapse.app.generic_worker import ( GenericWorkerReplicationHandler, GenericWorkerServer, ) from synapse.http.server import JsonResource -from synapse.http.site import SynapseRequest -from synapse.replication.http import ReplicationRestResource, streams +from synapse.http.site import SynapseRequest, SynapseSite +from synapse.replication.http import ReplicationRestResource from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory @@ -36,7 +37,12 @@ from synapse.server import HomeServer from synapse.util import Clock from tests import unittest -from tests.server import FakeTransport, render +from tests.server import FakeTransport + +try: + import hiredis +except ImportError: + hiredis = None logger = logging.getLogger(__name__) @@ -44,9 +50,10 @@ logger = logging.getLogger(__name__) class BaseStreamTestCase(unittest.HomeserverTestCase): """Base class for tests of the replication streams""" - servlets = [ - streams.register_servlets, - ] + # hiredis is an optional dependency so we don't want to require it for running + # the tests. + if not hiredis: + skip = "Requires hiredis" def prepare(self, reactor, clock, hs): # build a replication server @@ -57,8 +64,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # Make a new HomeServer object for the worker self.reactor.lookups["testserv"] = "1.2.3.4" self.worker_hs = self.setup_test_homeserver( - http_client=None, - homeserverToUse=GenericWorkerServer, + federation_http_client=None, + homeserver_to_use=GenericWorkerServer, config=self._get_worker_hs_config(), reactor=self.reactor, ) @@ -68,7 +75,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool self.test_handler = self._build_replication_data_handler() - self.worker_hs.replication_data_handler = self.test_handler + self.worker_hs._replication_data_handler = self.test_handler repl_handler = ReplicationCommandHandler(self.worker_hs) self.client = ClientReplicationStreamProtocol( @@ -78,6 +85,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self._client_transport = None self._server_transport = None + def create_resource_dict(self) -> Dict[str, Resource]: + d = super().create_resource_dict() + d["/_synapse/replication"] = ReplicationRestResource(self.hs) + return d + def _get_worker_hs_config(self) -> dict: config = self.default_config() config["worker_app"] = "synapse.app.generic_worker" @@ -197,23 +209,41 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): self.server_factory = ReplicationStreamProtocolFactory(self.hs) self.streamer = self.hs.get_replication_streamer() + # Fake in memory Redis server that servers can connect to. 
+ self._redis_server = FakeRedisPubSubServer() + store = self.hs.get_datastore() self.database_pool = store.db_pool self.reactor.lookups["testserv"] = "1.2.3.4" + self.reactor.lookups["localhost"] = "127.0.0.1" - self._worker_hs_to_resource = {} + # A map from a HS instance to the associated HTTP Site to use for + # handling inbound HTTP requests to that instance. + self._hs_to_site = {self.hs: self.site} + + if self.hs.config.redis.redis_enabled: + # Handle attempts to connect to fake redis server. + self.reactor.add_tcp_client_callback( + "localhost", 6379, self.connect_any_redis_attempts, + ) + + self.hs.get_tcp_replication().start_replication(self.hs) # When we see a connection attempt to the master replication listener we # automatically set up the connection. This is so that tests don't # manually have to go and explicitly set it up each time (plus sometimes # it is impossible to write the handling explicitly in the tests). + # + # Register the master replication listener: self.reactor.add_tcp_client_callback( - "1.2.3.4", 8765, self._handle_http_replication_attempt + "1.2.3.4", + 8765, + lambda: self._handle_http_replication_attempt(self.hs, 8765), ) - def create_test_json_resource(self): - """Overrides `HomeserverTestCase.create_test_json_resource`. + def create_test_resource(self): + """Overrides `HomeserverTestCase.create_test_resource`. """ # We override this so that it automatically registers all the HTTP # replication servlets, without having to explicitly do that in all @@ -236,7 +266,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): worker_app: Type of worker, e.g. `synapse.app.federation_sender`. extra_config: Any extra config to use for this instances. **kwargs: Options that get passed to `self.setup_test_homeserver`, - useful to e.g. pass some mocks for things like `http_client` + useful to e.g. pass some mocks for things like `federation_http_client` Returns: The new worker HomeServer instance. @@ -247,34 +277,69 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): config.update(extra_config) worker_hs = self.setup_test_homeserver( - homeserverToUse=GenericWorkerServer, + homeserver_to_use=GenericWorkerServer, config=config, reactor=self.reactor, - **kwargs + **kwargs, ) + # If the instance is in the `instance_map` config then workers may try + # and send HTTP requests to it, so we register it with + # `_handle_http_replication_attempt` like we do with the master HS. + instance_name = worker_hs.get_instance_name() + instance_loc = worker_hs.config.worker.instance_map.get(instance_name) + if instance_loc: + # Ensure the host is one that has a fake DNS entry. + if instance_loc.host not in self.reactor.lookups: + raise Exception( + "Host does not have an IP for instance_map[%r].host = %r" + % (instance_name, instance_loc.host,) + ) + + self.reactor.add_tcp_client_callback( + self.reactor.lookups[instance_loc.host], + instance_loc.port, + lambda: self._handle_http_replication_attempt( + worker_hs, instance_loc.port + ), + ) + store = worker_hs.get_datastore() store.db_pool._db_pool = self.database_pool._db_pool - repl_handler = ReplicationCommandHandler(worker_hs) - client = ClientReplicationStreamProtocol( - worker_hs, "client", "test", self.clock, repl_handler, - ) - server = self.server_factory.buildProtocol(None) + # Set up TCP replication between master and the new worker if we don't + # have Redis support enabled. 
+ if not worker_hs.config.redis_enabled: + repl_handler = ReplicationCommandHandler(worker_hs) + client = ClientReplicationStreamProtocol( + worker_hs, "client", "test", self.clock, repl_handler, + ) + server = self.server_factory.buildProtocol(None) - client_transport = FakeTransport(server, self.reactor) - client.makeConnection(client_transport) + client_transport = FakeTransport(server, self.reactor) + client.makeConnection(client_transport) - server_transport = FakeTransport(client, self.reactor) - server.makeConnection(server_transport) + server_transport = FakeTransport(client, self.reactor) + server.makeConnection(server_transport) # Set up a resource for the worker - resource = ReplicationRestResource(self.hs) + resource = ReplicationRestResource(worker_hs) for servlet in self.servlets: servlet(worker_hs, resource) - self._worker_hs_to_resource[worker_hs] = resource + self._hs_to_site[worker_hs] = SynapseSite( + logger_name="synapse.access.http.fake", + site_tag="{}-{}".format( + worker_hs.config.server.server_name, worker_hs.get_instance_name() + ), + config=worker_hs.config.server.listeners[0], + resource=resource, + server_version_string="1", + ) + + if worker_hs.config.redis.redis_enabled: + worker_hs.get_tcp_replication().start_replication(worker_hs) return worker_hs @@ -284,9 +349,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): config["worker_replication_http_port"] = "8765" return config - def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest): - render(request, self._worker_hs_to_resource[worker_hs], self.reactor) - def replicate(self): """Tell the master side of replication that something has happened, and then wait for the replication to occur. @@ -294,9 +356,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): self.streamer.on_notifier_poke() self.pump() - def _handle_http_replication_attempt(self): - """Handles a connection attempt to the master replication HTTP - listener. + def _handle_http_replication_attempt(self, hs, repl_port): + """Handles a connection attempt to the given HS replication HTTP + listener on the given port. """ # We should have at least one outbound connection attempt, where the @@ -305,7 +367,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): self.assertGreaterEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop() self.assertEqual(host, "1.2.3.4") - self.assertEqual(port, 8765) + self.assertEqual(port, repl_port) # Set up client side protocol client_protocol = client_factory.buildProtocol(None) @@ -315,7 +377,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # Set up the server side protocol channel = _PushHTTPChannel(self.reactor) channel.requestFactory = request_factory - channel.site = self.site + channel.site = self._hs_to_site[hs] # Connect client to server and vice versa. client_to_server_transport = FakeTransport( @@ -333,6 +395,32 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # inside `connecTCP` before the connection has been passed back to the # code that requested the TCP connection. + def connect_any_redis_attempts(self): + """If redis is enabled we need to deal with workers connecting to a + redis server. We don't want to use a real Redis server so we use a + fake one. 
+ """ + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "localhost") + self.assertEqual(port, 6379) + + client_protocol = client_factory.buildProtocol(None) + server_protocol = self._redis_server.buildProtocol(None) + + client_to_server_transport = FakeTransport( + server_protocol, self.reactor, client_protocol + ) + client_protocol.makeConnection(client_to_server_transport) + + server_to_client_transport = FakeTransport( + client_protocol, self.reactor, server_protocol + ) + server_protocol.makeConnection(server_to_client_transport) + + return client_to_server_transport, server_to_client_transport + class TestReplicationDataHandler(GenericWorkerReplicationHandler): """Drop-in for ReplicationDataHandler which just collects RDATA rows""" @@ -467,3 +555,105 @@ class _PullToPushProducer: pass self.stopProducing() + + +class FakeRedisPubSubServer: + """A fake Redis server for pub/sub. + """ + + def __init__(self): + self._subscribers = set() + + def add_subscriber(self, conn): + """A connection has called SUBSCRIBE + """ + self._subscribers.add(conn) + + def remove_subscriber(self, conn): + """A connection has called UNSUBSCRIBE + """ + self._subscribers.discard(conn) + + def publish(self, conn, channel, msg) -> int: + """A connection want to publish a message to subscribers. + """ + for sub in self._subscribers: + sub.send(["message", channel, msg]) + + return len(self._subscribers) + + def buildProtocol(self, addr): + return FakeRedisPubSubProtocol(self) + + +class FakeRedisPubSubProtocol(Protocol): + """A connection from a client talking to the fake Redis server. + """ + + def __init__(self, server: FakeRedisPubSubServer): + self._server = server + self._reader = hiredis.Reader() + + def dataReceived(self, data): + self._reader.feed(data) + + # We might get multiple messages in one packet. + while True: + msg = self._reader.gets() + + if msg is False: + # No more messages. + return + + if not isinstance(msg, list): + # Inbound commands should always be a list + raise Exception("Expected redis list") + + self.handle_command(msg[0], *msg[1:]) + + def handle_command(self, command, *args): + """Received a Redis command from the client. + """ + + # We currently only support pub/sub. + if command == b"PUBLISH": + channel, message = args + num_subscribers = self._server.publish(self, channel, message) + self.send(num_subscribers) + elif command == b"SUBSCRIBE": + (channel,) = args + self._server.add_subscriber(self) + self.send(["subscribe", channel, 1]) + else: + raise Exception("Unknown command") + + def send(self, msg): + """Send a message back to the client. + """ + raw = self.encode(msg).encode("utf-8") + + self.transport.write(raw) + self.transport.flush() + + def encode(self, obj): + """Encode an object to its Redis format. + + Supports: strings/bytes, integers and list/tuples. + """ + + if isinstance(obj, bytes): + # We assume bytes are just unicode strings. 
+ obj = obj.decode("utf-8") + + if isinstance(obj, str): + return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj) + if isinstance(obj, int): + return ":{val}\r\n".format(val=obj) + if isinstance(obj, (list, tuple)): + items = "".join(self.encode(a) for a in obj) + return "*{len}\r\n{items}".format(len=len(obj), items=items) + + raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj) + + def connectionLost(self, reason): + self._server.remove_subscriber(self) diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index c9998e88e6..bad0df08cf 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py
@@ -449,7 +449,7 @@ class EventsStreamTestCase(BaseStreamTestCase): sender=sender, type="test_event", content={"body": body}, - **kwargs + **kwargs, ) ) diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py new file mode 100644
index 0000000000..f35a5235e1 --- /dev/null +++ b/tests/replication/test_auth.py
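This new test file covers the worker_replication_secret option: replication HTTP requests from a worker to the main process fail unless the worker presents the secret the main process expects. The fake main_replication_secret key is popped out again in make_homeserver; it exists only so the main and worker homeservers can be given different values in one override_config call. Conceptually the check amounts to comparing a shared secret carried in a request header; a loose sketch under that assumption, not Synapse's actual implementation:

    import hmac

    def check_replication_secret(authorization_header, configured_secret):
        # If no secret is configured, the request is allowed through unchecked.
        if configured_secret is None:
            return True
        expected = "Bearer %s" % (configured_secret,)
        return hmac.compare_digest(authorization_header or "", expected)

    assert check_replication_secret("Bearer my-secret", "my-secret")
    assert not check_replication_secret("Bearer wrong-secret", "my-secret")
    assert not check_replication_secret(None, "my-secret")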
@@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from synapse.rest.client.v2_alpha import register + +from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.server import FakeChannel, make_request +from tests.unittest import override_config + +logger = logging.getLogger(__name__) + + +class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase): + """Test the authentication of HTTP calls between workers.""" + + servlets = [register.register_servlets] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + # This isn't a real configuration option but is used to provide the main + # homeserver and worker homeserver different options. + main_replication_secret = config.pop("main_replication_secret", None) + if main_replication_secret: + config["worker_replication_secret"] = main_replication_secret + return self.setup_test_homeserver(config=config) + + def _get_worker_hs_config(self) -> dict: + config = self.default_config() + config["worker_app"] = "synapse.app.client_reader" + config["worker_replication_host"] = "testserv" + config["worker_replication_http_port"] = "8765" + + return config + + def _test_register(self) -> FakeChannel: + """Run the actual test: + + 1. Create a worker homeserver. + 2. Start registration by providing a user/password. + 3. Complete registration by providing dummy auth (this hits the main synapse). + 4. Return the final request. + + """ + worker_hs = self.make_worker_hs("synapse.app.client_reader") + site = self._hs_to_site[worker_hs] + + channel_1 = make_request( + self.reactor, + site, + "POST", + "register", + {"username": "user", "type": "m.login.password", "password": "bar"}, + ) + self.assertEqual(channel_1.code, 401) + + # Grab the session + session = channel_1.json_body["session"] + + # also complete the dummy auth + return make_request( + self.reactor, + site, + "POST", + "register", + {"auth": {"session": session, "type": "m.login.dummy"}}, + ) + + def test_no_auth(self): + """With no authentication the request should finish. + """ + channel = self._test_register() + self.assertEqual(channel.code, 200) + + # We're given a registered user. + self.assertEqual(channel.json_body["user_id"], "@user:test") + + @override_config({"main_replication_secret": "my-secret"}) + def test_missing_auth(self): + """If the main process expects a secret that is not provided, an error results. + """ + channel = self._test_register() + self.assertEqual(channel.code, 500) + + @override_config( + { + "main_replication_secret": "my-secret", + "worker_replication_secret": "wrong-secret", + } + ) + def test_unauthorized(self): + """If the main process receives the wrong secret, an error results. 
+ """ + channel = self._test_register() + self.assertEqual(channel.code, 500) + + @override_config({"worker_replication_secret": "my-secret"}) + def test_authorized(self): + """The request should finish when the worker provides the authentication header. + """ + channel = self._test_register() + self.assertEqual(channel.code, 200) + + # We're given a registered user. + self.assertEqual(channel.json_body["user_id"], "@user:test") diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index 86c03fd89c..4608b65a0c 100644 --- a/tests/replication/test_client_reader_shard.py +++ b/tests/replication/test_client_reader_shard.py
@@ -14,27 +14,19 @@ # limitations under the License. import logging -from synapse.api.constants import LoginType -from synapse.http.site import SynapseRequest from synapse.rest.client.v2_alpha import register from tests.replication._base import BaseMultiWorkerStreamTestCase -from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker -from tests.server import FakeChannel +from tests.server import make_request logger = logging.getLogger(__name__) class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): - """Base class for tests of the replication streams""" + """Test using one or more client readers for registration.""" servlets = [register.register_servlets] - def prepare(self, reactor, clock, hs): - self.recaptcha_checker = DummyRecaptchaChecker(hs) - auth_handler = hs.get_auth_handler() - auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker - def _get_worker_hs_config(self) -> dict: config = self.default_config() config["worker_app"] = "synapse.app.client_reader" @@ -46,24 +38,29 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): """Test that registration works when using a single client reader worker. """ worker_hs = self.make_worker_hs("synapse.app.client_reader") + site = self._hs_to_site[worker_hs] - request_1, channel_1 = self.make_request( + channel_1 = make_request( + self.reactor, + site, "POST", "register", {"username": "user", "type": "m.login.password", "password": "bar"}, - ) # type: SynapseRequest, FakeChannel - self.render_on_worker(worker_hs, request_1) - self.assertEqual(request_1.code, 401) + ) + self.assertEqual(channel_1.code, 401) # Grab the session session = channel_1.json_body["session"] # also complete the dummy auth - request_2, channel_2 = self.make_request( - "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} - ) # type: SynapseRequest, FakeChannel - self.render_on_worker(worker_hs, request_2) - self.assertEqual(request_2.code, 200) + channel_2 = make_request( + self.reactor, + site, + "POST", + "register", + {"auth": {"session": session, "type": "m.login.dummy"}}, + ) + self.assertEqual(channel_2.code, 200) # We're given a registered user. self.assertEqual(channel_2.json_body["user_id"], "@user:test") @@ -74,23 +71,29 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): worker_hs_1 = self.make_worker_hs("synapse.app.client_reader") worker_hs_2 = self.make_worker_hs("synapse.app.client_reader") - request_1, channel_1 = self.make_request( + site_1 = self._hs_to_site[worker_hs_1] + channel_1 = make_request( + self.reactor, + site_1, "POST", "register", {"username": "user", "type": "m.login.password", "password": "bar"}, - ) # type: SynapseRequest, FakeChannel - self.render_on_worker(worker_hs_1, request_1) - self.assertEqual(request_1.code, 401) + ) + self.assertEqual(channel_1.code, 401) # Grab the session session = channel_1.json_body["session"] # also complete the dummy auth - request_2, channel_2 = self.make_request( - "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} - ) # type: SynapseRequest, FakeChannel - self.render_on_worker(worker_hs_2, request_2) - self.assertEqual(request_2.code, 200) + site_2 = self._hs_to_site[worker_hs_2] + channel_2 = make_request( + self.reactor, + site_2, + "POST", + "register", + {"auth": {"session": session, "type": "m.login.dummy"}}, + ) + self.assertEqual(channel_2.code, 200) # We're given a registered user. 
self.assertEqual(channel_2.json_body["user_id"], "@user:test") diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 23be1167a3..1853667558 100644 --- a/tests/replication/test_federation_ack.py +++ b/tests/replication/test_federation_ack.py
@@ -31,7 +31,7 @@ class FederationAckTestCase(HomeserverTestCase): return config def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver(homeserverToUse=GenericWorkerServer) + hs = self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer) return hs diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 1d7edee5ba..fffdb742c8 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py
@@ -50,7 +50,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): self.make_worker_hs( "synapse.app.federation_sender", {"send_federation": True}, - http_client=mock_client, + federation_http_client=mock_client, ) user = self.register_user("user", "pass") @@ -81,7 +81,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): "worker_name": "sender1", "federation_sender_instances": ["sender1", "sender2"], }, - http_client=mock_client1, + federation_http_client=mock_client1, ) mock_client2 = Mock(spec=["put_json"]) @@ -93,7 +93,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): "worker_name": "sender2", "federation_sender_instances": ["sender1", "sender2"], }, - http_client=mock_client2, + federation_http_client=mock_client2, ) user = self.register_user("user2", "pass") @@ -144,7 +144,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): "worker_name": "sender1", "federation_sender_instances": ["sender1", "sender2"], }, - http_client=mock_client1, + federation_http_client=mock_client1, ) mock_client2 = Mock(spec=["put_json"]) @@ -156,7 +156,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): "worker_name": "sender2", "federation_sender_instances": ["sender1", "sender2"], }, - http_client=mock_client2, + federation_http_client=mock_client2, ) user = self.register_user("user3", "pass") @@ -207,7 +207,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): def create_room_with_remote_server(self, user, token, remote_server="other_server"): room = self.helper.create_room_as(user, tok=token) store = self.hs.get_datastore() - federation = self.hs.get_handlers().federation_handler + federation = self.hs.get_federation_handler() prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room)) room_version = self.get_success(store.get_room_version(room)) @@ -226,7 +226,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): } builder = factory.for_room_version(room_version, event_dict) - join_event = self.get_success(builder.build(prev_event_ids)) + join_event = self.get_success(builder.build(prev_event_ids, None)) self.get_success(federation.on_send_join_request(remote_server, join_event)) self.replicate() diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py new file mode 100644
index 0000000000..d1feca961f --- /dev/null +++ b/tests/replication/test_multi_media_repo.py
@@ -0,0 +1,279 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from binascii import unhexlify +from typing import Tuple + +from twisted.internet.protocol import Factory +from twisted.protocols.tls import TLSMemoryBIOFactory +from twisted.web.http import HTTPChannel +from twisted.web.server import Request + +from synapse.rest import admin +from synapse.rest.client.v1 import login +from synapse.server import HomeServer + +from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file +from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.server import FakeChannel, FakeSite, FakeTransport, make_request + +logger = logging.getLogger(__name__) + +test_server_connection_factory = None + + +class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): + """Checks running multiple media repos work correctly. + """ + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.user_id = self.register_user("user", "pass") + self.access_token = self.login("user", "pass") + + self.reactor.lookups["example.com"] = "1.2.3.4" + + def default_config(self): + conf = super().default_config() + conf["federation_custom_ca_list"] = [get_test_ca_cert_file()] + return conf + + def _get_media_req( + self, hs: HomeServer, target: str, media_id: str + ) -> Tuple[FakeChannel, Request]: + """Request some remote media from the given HS by calling the download + API. + + This then triggers an outbound request from the HS to the target. + + Returns: + The channel for the *client* request and the *outbound* request for + the media which the caller should respond to. + """ + resource = hs.get_media_repository_resource().children[b"download"] + channel = make_request( + self.reactor, + FakeSite(resource), + "GET", + "/{}/{}".format(target, media_id), + shorthand=False, + access_token=self.access_token, + await_result=False, + ) + self.pump() + + clients = self.reactor.tcpClients + self.assertGreaterEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop() + + # build the test server + server_tls_protocol = _build_test_server(get_connection_factory()) + + # now, tell the client protocol factory to build the client protocol (it will be a + # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an + # HTTP11ClientProtocol) and wire the output of said protocol up to the server via + # a FakeTransport. + # + # Normally this would be done by the TCP socket code in Twisted, but we are + # stubbing that out here. 
+ client_protocol = client_factory.buildProtocol(None) + client_protocol.makeConnection( + FakeTransport(server_tls_protocol, self.reactor, client_protocol) + ) + + # tell the server tls protocol to send its stuff back to the client, too + server_tls_protocol.makeConnection( + FakeTransport(client_protocol, self.reactor, server_tls_protocol) + ) + + # fish the test server back out of the server-side TLS protocol. + http_server = server_tls_protocol.wrappedProtocol + + # give the reactor a pump to get the TLS juices flowing. + self.reactor.pump((0.1,)) + + self.assertEqual(len(http_server.requests), 1) + request = http_server.requests[0] + + self.assertEqual(request.method, b"GET") + self.assertEqual( + request.path, + "/_matrix/media/r0/download/{}/{}".format(target, media_id).encode("utf-8"), + ) + self.assertEqual( + request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")] + ) + + return channel, request + + def test_basic(self): + """Test basic fetching of remote media from a single worker. + """ + hs1 = self.make_worker_hs("synapse.app.generic_worker") + + channel, request = self._get_media_req(hs1, "example.com:443", "ABC123") + + request.setResponseCode(200) + request.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"]) + request.write(b"Hello!") + request.finish() + + self.pump(0.1) + + self.assertEqual(channel.code, 200) + self.assertEqual(channel.result["body"], b"Hello!") + + def test_download_simple_file_race(self): + """Test that fetching remote media from two different processes at the + same time works. + """ + hs1 = self.make_worker_hs("synapse.app.generic_worker") + hs2 = self.make_worker_hs("synapse.app.generic_worker") + + start_count = self._count_remote_media() + + # Make two requests without responding to the outbound media requests. + channel1, request1 = self._get_media_req(hs1, "example.com:443", "ABC123") + channel2, request2 = self._get_media_req(hs2, "example.com:443", "ABC123") + + # Respond to the first outbound media request and check that the client + # request is successful + request1.setResponseCode(200) + request1.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"]) + request1.write(b"Hello!") + request1.finish() + + self.pump(0.1) + + self.assertEqual(channel1.code, 200, channel1.result["body"]) + self.assertEqual(channel1.result["body"], b"Hello!") + + # Now respond to the second with the same content. + request2.setResponseCode(200) + request2.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"]) + request2.write(b"Hello!") + request2.finish() + + self.pump(0.1) + + self.assertEqual(channel2.code, 200, channel2.result["body"]) + self.assertEqual(channel2.result["body"], b"Hello!") + + # We expect only one new file to have been persisted. + self.assertEqual(start_count + 1, self._count_remote_media()) + + def test_download_image_race(self): + """Test that fetching remote *images* from two different processes at + the same time works. + + This checks that races generating thumbnails are handled correctly. 
+ """ + hs1 = self.make_worker_hs("synapse.app.generic_worker") + hs2 = self.make_worker_hs("synapse.app.generic_worker") + + start_count = self._count_remote_thumbnails() + + channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1") + channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1") + + png_data = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + + request1.setResponseCode(200) + request1.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"]) + request1.write(png_data) + request1.finish() + + self.pump(0.1) + + self.assertEqual(channel1.code, 200, channel1.result["body"]) + self.assertEqual(channel1.result["body"], png_data) + + request2.setResponseCode(200) + request2.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"]) + request2.write(png_data) + request2.finish() + + self.pump(0.1) + + self.assertEqual(channel2.code, 200, channel2.result["body"]) + self.assertEqual(channel2.result["body"], png_data) + + # We expect only three new thumbnails to have been persisted. + self.assertEqual(start_count + 3, self._count_remote_thumbnails()) + + def _count_remote_media(self) -> int: + """Count the number of files in our remote media directory. + """ + path = os.path.join( + self.hs.get_media_repository().primary_base_path, "remote_content" + ) + return sum(len(files) for _, _, files in os.walk(path)) + + def _count_remote_thumbnails(self) -> int: + """Count the number of files in our remote thumbnails directory. + """ + path = os.path.join( + self.hs.get_media_repository().primary_base_path, "remote_thumbnail" + ) + return sum(len(files) for _, _, files in os.walk(path)) + + +def get_connection_factory(): + # this needs to happen once, but not until we are ready to run the first test + global test_server_connection_factory + if test_server_connection_factory is None: + test_server_connection_factory = TestServerTLSConnectionFactory( + sanlist=[b"DNS:example.com"] + ) + return test_server_connection_factory + + +def _build_test_server(connection_creator): + """Construct a test server + + This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol + + Args: + connection_creator (IOpenSSLServerConnectionCreator): thing to build + SSL connections + sanlist (list[bytes]): list of the SAN entries for the cert returned + by the server + + Returns: + TLSMemoryBIOProtocol + """ + server_factory = Factory.forProtocol(HTTPChannel) + # Request.finish expects the factory to have a 'log' method. + server_factory.log = _log_request + + server_tls_factory = TLSMemoryBIOFactory( + connection_creator, isClient=False, wrappedFactory=server_factory + ) + + return server_tls_factory.buildProtocol(None) + + +def _log_request(request): + """Implements Factory.log, which is expected by Request.finish""" + logger.info("Completed request %s", request) diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 2bdc6edbb1..800ad94a04 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py
@@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): user_dict = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_dict["token_id"] + token_id = user_dict.token_id self.get_success( self.hs.get_pusherpool().add_pusher( @@ -67,7 +67,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): device_display_name="pushy push", pushkey="a@example.com", lang=None, - data={"url": "https://push.example.com/push"}, + data={"url": "https://push.example.com/_matrix/push/v1/notify"}, ) ) @@ -98,7 +98,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): self.make_worker_hs( "synapse.app.pusher", {"start_pushers": True}, - proxied_http_client=http_client_mock, + proxied_blacklisted_http_client=http_client_mock, ) event_id = self._create_pusher_and_send_msg("user") @@ -109,7 +109,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): http_client_mock.post_json_get_json.assert_called_once() self.assertEqual( http_client_mock.post_json_get_json.call_args[0][0], - "https://push.example.com/push", + "https://push.example.com/_matrix/push/v1/notify", ) self.assertEqual( event_id, @@ -133,7 +133,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): "worker_name": "pusher1", "pusher_instances": ["pusher1", "pusher2"], }, - proxied_http_client=http_client_mock1, + proxied_blacklisted_http_client=http_client_mock1, ) http_client_mock2 = Mock(spec_set=["post_json_get_json"]) @@ -148,7 +148,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): "worker_name": "pusher2", "pusher_instances": ["pusher1", "pusher2"], }, - proxied_http_client=http_client_mock2, + proxied_blacklisted_http_client=http_client_mock2, ) # We choose a user name that we know should go to pusher1. @@ -161,7 +161,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): http_client_mock2.post_json_get_json.assert_not_called() self.assertEqual( http_client_mock1.post_json_get_json.call_args[0][0], - "https://push.example.com/push", + "https://push.example.com/_matrix/push/v1/notify", ) self.assertEqual( event_id, @@ -183,7 +183,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): http_client_mock2.post_json_get_json.assert_called_once() self.assertEqual( http_client_mock2.post_json_get_json.call_args[0][0], - "https://push.example.com/push", + "https://push.example.com/_matrix/push/v1/notify", ) self.assertEqual( event_id, diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py new file mode 100644
index 0000000000..8d494ebc03 --- /dev/null +++ b/tests/replication/test_sharded_event_persister.py
@@ -0,0 +1,333 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from mock import patch + +from synapse.api.room_versions import RoomVersion +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import sync + +from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.server import make_request +from tests.utils import USE_POSTGRES_FOR_TESTS + +logger = logging.getLogger(__name__) + + +class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): + """Checks that event persister sharding works + """ + + # Event persister sharding requires postgres (due to needing + # `MultiWriterIdGenerator`). + if not USE_POSTGRES_FOR_TESTS: + skip = "Requires Postgres" + + servlets = [ + admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + sync.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + # Register a user who sends a message that we'll get notified about + self.other_user_id = self.register_user("otheruser", "pass") + self.other_access_token = self.login("otheruser", "pass") + + self.room_creator = self.hs.get_room_creation_handler() + self.store = hs.get_datastore() + + def default_config(self): + conf = super().default_config() + conf["redis"] = {"enabled": "true"} + conf["stream_writers"] = {"events": ["worker1", "worker2"]} + conf["instance_map"] = { + "worker1": {"host": "testserv", "port": 1001}, + "worker2": {"host": "testserv", "port": 1002}, + } + return conf + + def _create_room(self, room_id: str, user_id: str, tok: str): + """Create a room with the given room_id + """ + + # We control the room ID generation by patching out the + # `_generate_room_id` method + async def generate_room( + creator_id: str, is_public: bool, room_version: RoomVersion + ): + await self.store.store_room( + room_id=room_id, + room_creator_user_id=creator_id, + is_public=is_public, + room_version=room_version, + ) + return room_id + + with patch( + "synapse.handlers.room.RoomCreationHandler._generate_room_id" + ) as mock: + mock.side_effect = generate_room + self.helper.create_room_as(user_id, tok=tok) + + def test_basic(self): + """Simple test to ensure that multiple rooms can be created and joined, + and that different rooms get handled by different instances. + """ + + self.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "worker1"}, + ) + + self.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "worker2"}, + ) + + persisted_on_1 = False + persisted_on_2 = False + + store = self.hs.get_datastore() + + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Keep making new rooms until we see rooms being persisted on both + # workers. 
+ for _ in range(10): + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # The other user joins + self.helper.join( + room=room, user=self.other_user_id, tok=self.other_access_token + ) + + # The other user sends some messages + response = self.helper.send(room, body="Hi!", tok=self.other_access_token) + event_id = response["event_id"] + + # The event position includes which instance persisted the event. + pos = self.get_success(store.get_position_for_event(event_id)) + + persisted_on_1 |= pos.instance_name == "worker1" + persisted_on_2 |= pos.instance_name == "worker2" + + if persisted_on_1 and persisted_on_2: + break + + self.assertTrue(persisted_on_1) + self.assertTrue(persisted_on_2) + + def test_vector_clock_token(self): + """Tests that using a stream token with a vector clock component works + correctly with basic /sync and /messages usage. + """ + + self.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "worker1"}, + ) + + worker_hs2 = self.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "worker2"}, + ) + + sync_hs = self.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "sync"}, + ) + sync_hs_site = self._hs_to_site[sync_hs] + + # Specially selected room IDs that get persisted on different workers. + room_id1 = "!foo:test" + room_id2 = "!baz:test" + + self.assertEqual( + self.hs.config.worker.events_shard_config.get_instance(room_id1), "worker1" + ) + self.assertEqual( + self.hs.config.worker.events_shard_config.get_instance(room_id2), "worker2" + ) + + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + store = self.hs.get_datastore() + + # Create two rooms on the different workers. + self._create_room(room_id1, user_id, access_token) + self._create_room(room_id2, user_id, access_token) + + # The other user joins + self.helper.join( + room=room_id1, user=self.other_user_id, tok=self.other_access_token + ) + self.helper.join( + room=room_id2, user=self.other_user_id, tok=self.other_access_token + ) + + # Do an initial sync so that we're up to date. + channel = make_request( + self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token + ) + next_batch = channel.json_body["next_batch"] + + # We now gut wrench into the events stream MultiWriterIdGenerator on + # worker2 to mimic it getting stuck persisting an event. This ensures + # that when we send an event on worker1 we end up in a state where + # worker2's event stream position lags behind worker1's, resulting in a + # RoomStreamToken with a non-empty instance map component. + # + # Worker2's event stream position will not advance until we call + # __aexit__ again. + actx = worker_hs2.get_datastore()._stream_id_gen.get_next() + self.get_success(actx.__aenter__()) + + response = self.helper.send(room_id1, body="Hi!", tok=self.other_access_token) + first_event_in_room1 = response["event_id"] + + # Assert that the current stream token has an instance map component, as + # we are trying to test vector clock tokens. + room_stream_token = store.get_room_max_token() + self.assertNotEqual(len(room_stream_token.instance_map), 0) + + # Check that syncing still gets the new event, despite the gap in the + # stream IDs. 
+ channel = make_request( + self.reactor, + sync_hs_site, + "GET", + "/sync?since={}".format(next_batch), + access_token=access_token, + ) + + # We should only see the new event and nothing else + self.assertIn(room_id1, channel.json_body["rooms"]["join"]) + self.assertNotIn(room_id2, channel.json_body["rooms"]["join"]) + + events = channel.json_body["rooms"]["join"][room_id1]["timeline"]["events"] + self.assertListEqual( + [first_event_in_room1], [event["event_id"] for event in events] + ) + + # Get the next batch and make sure it's a vector clock style token. + vector_clock_token = channel.json_body["next_batch"] + self.assertTrue(vector_clock_token.startswith("m")) + + # Now that we've got a vector clock token we finish the fake persisting + # an event we started above. + self.get_success(actx.__aexit__(None, None, None)) + + # Now try and send an event to the other room so that we can test that + # the vector clock style token works as a `since` token. + response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token) + first_event_in_room2 = response["event_id"] + + channel = make_request( + self.reactor, + sync_hs_site, + "GET", + "/sync?since={}".format(vector_clock_token), + access_token=access_token, + ) + + self.assertNotIn(room_id1, channel.json_body["rooms"]["join"]) + self.assertIn(room_id2, channel.json_body["rooms"]["join"]) + + events = channel.json_body["rooms"]["join"][room_id2]["timeline"]["events"] + self.assertListEqual( + [first_event_in_room2], [event["event_id"] for event in events] + ) + + next_batch = channel.json_body["next_batch"] + + # We also want to test that the vector clock style token works with + # pagination. We do this by sending a couple of new events into the room + # and syncing again to get a prev_batch token for each room, then + # paginating from there back to the vector clock token. + self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token) + self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token) + + channel = make_request( + self.reactor, + sync_hs_site, + "GET", + "/sync?since={}".format(next_batch), + access_token=access_token, + ) + + prev_batch1 = channel.json_body["rooms"]["join"][room_id1]["timeline"][ + "prev_batch" + ] + prev_batch2 = channel.json_body["rooms"]["join"][room_id2]["timeline"][ + "prev_batch" + ] + + # Paginating back in the first room should not produce any results, as + # no events have happened in it. This tests that we are correctly + # filtering results based on the vector clock portion. + channel = make_request( + self.reactor, + sync_hs_site, + "GET", + "/rooms/{}/messages?from={}&to={}&dir=b".format( + room_id1, prev_batch1, vector_clock_token + ), + access_token=access_token, + ) + self.assertListEqual([], channel.json_body["chunk"]) + + # Paginating back on the second room should produce the first event + # again. This tests that pagination isn't completely broken. 
+ channel = make_request( + self.reactor, + sync_hs_site, + "GET", + "/rooms/{}/messages?from={}&to={}&dir=b".format( + room_id2, prev_batch2, vector_clock_token + ), + access_token=access_token, + ) + self.assertEqual(len(channel.json_body["chunk"]), 1) + self.assertEqual( + channel.json_body["chunk"][0]["event_id"], first_event_in_room2 + ) + + # Paginating forwards should give the same results + channel = make_request( + self.reactor, + sync_hs_site, + "GET", + "/rooms/{}/messages?from={}&to={}&dir=f".format( + room_id1, vector_clock_token, prev_batch1 + ), + access_token=access_token, + ) + self.assertListEqual([], channel.json_body["chunk"]) + + channel = make_request( + self.reactor, + sync_hs_site, + "GET", + "/rooms/{}/messages?from={}&to={}&dir=f".format( + room_id2, vector_clock_token, prev_batch2, + ), + access_token=access_token, + ) + self.assertEqual(len(channel.json_body["chunk"]), 1) + self.assertEqual( + channel.json_body["chunk"][0]["event_id"], first_event_in_room2 + ) diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 0f1144fe1e..0504cd187e 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py
@@ -30,19 +30,19 @@ from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import groups from tests import unittest +from tests.server import FakeSite, make_request class VersionTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/server_version" - def create_test_json_resource(self): + def create_test_resource(self): resource = JsonResource(self.hs) VersionServlet(self.hs).register(resource) return resource def test_version_string(self): - request, channel = self.make_request("GET", self.url, shorthand=False) - self.render(request) + channel = self.make_request("GET", self.url, shorthand=False) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( @@ -68,14 +68,13 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): def test_delete_group(self): # Create a new group - request, channel = self.make_request( + channel = self.make_request( "POST", "/create_group".encode("ascii"), access_token=self.admin_user_tok, content={"localpart": "test"}, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) group_id = channel.json_body["group_id"] @@ -85,17 +84,15 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): # Invite/join another user url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user) - request, channel = self.make_request( + channel = self.make_request( "PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={} ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) url = "/groups/%s/self/accept_invite" % (group_id,) - request, channel = self.make_request( + channel = self.make_request( "PUT", url.encode("ascii"), access_token=self.other_user_token, content={} ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Check other user knows they're in the group @@ -103,15 +100,14 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): self.assertIn(group_id, self._get_groups_user_is_in(self.other_user_token)) # Now delete the group - url = "/admin/delete_group/" + group_id - request, channel = self.make_request( + url = "/_synapse/admin/v1/delete_group/" + group_id + channel = self.make_request( "POST", url.encode("ascii"), access_token=self.admin_user_tok, content={"localpart": "test"}, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Check group returns 404 @@ -127,11 +123,10 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): """ url = "/groups/%s/profile" % (group_id,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) - self.render(request) self.assertEqual( expect_code, int(channel.result["code"]), msg=channel.result["body"] ) @@ -139,11 +134,10 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): def _get_groups_user_is_in(self, access_token): """Returns the list of groups the user is in (given their access token) """ - request, channel = self.make_request( + channel = self.make_request( "GET", "/joined_groups".encode("ascii"), access_token=access_token ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) return channel.json_body["groups"] @@ -216,17 +210,20 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): } config["media_storage_providers"] = [provider_config] - hs = 
self.setup_test_homeserver(config=config, http_client=client) + hs = self.setup_test_homeserver(config=config, federation_http_client=client) return hs def _ensure_quarantined(self, admin_user_tok, server_and_media_id): """Ensure a piece of media is quarantined when trying to access it.""" - request, channel = self.make_request( - "GET", server_and_media_id, shorthand=False, access_token=admin_user_tok, + channel = make_request( + self.reactor, + FakeSite(self.download_resource), + "GET", + server_and_media_id, + shorthand=False, + access_token=admin_user_tok, ) - request.render(self.download_resource) - self.pump(1.0) # Should be quarantined self.assertEqual( @@ -244,10 +241,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Attempt quarantine media APIs as non-admin url = "/_synapse/admin/v1/media/quarantine/example.org/abcde12345" - request, channel = self.make_request( + channel = self.make_request( "POST", url.encode("ascii"), access_token=non_admin_user_tok, ) - self.render(request) # Expect a forbidden error self.assertEqual( @@ -258,10 +254,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # And the roomID/userID endpoint url = "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine" - request, channel = self.make_request( + channel = self.make_request( "POST", url.encode("ascii"), access_token=non_admin_user_tok, ) - self.render(request) # Expect a forbidden error self.assertEqual( @@ -287,14 +282,14 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): server_name, media_id = server_name_and_media_id.split("/") # Attempt to access the media - request, channel = self.make_request( + channel = make_request( + self.reactor, + FakeSite(self.download_resource), "GET", server_name_and_media_id, shorthand=False, access_token=non_admin_user_tok, ) - request.render(self.download_resource) - self.pump(1.0) # Should be successful self.assertEqual(200, int(channel.code), msg=channel.result["body"]) @@ -304,8 +299,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): urllib.parse.quote(server_name), urllib.parse.quote(media_id), ) - request, channel = self.make_request("POST", url, access_token=admin_user_tok,) - self.render(request) + channel = self.make_request("POST", url, access_token=admin_user_tok,) self.pump(1.0) self.assertEqual(200, int(channel.code), msg=channel.result["body"]) @@ -357,8 +351,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/room/%s/media/quarantine" % urllib.parse.quote( room_id ) - request, channel = self.make_request("POST", url, access_token=admin_user_tok,) - self.render(request) + channel = self.make_request("POST", url, access_token=admin_user_tok,) self.pump(1.0) self.assertEqual(200, int(channel.code), msg=channel.result["body"]) self.assertEqual( @@ -402,10 +395,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote( non_admin_user ) - request, channel = self.make_request( + channel = self.make_request( "POST", url.encode("ascii"), access_token=admin_user_tok, ) - self.render(request) self.pump(1.0) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( @@ -445,10 +437,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote( non_admin_user ) - request, channel = self.make_request( + channel = self.make_request( "POST", url.encode("ascii"), 
access_token=admin_user_tok, ) - self.render(request) self.pump(1.0) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( @@ -462,14 +453,14 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): self._ensure_quarantined(admin_user_tok, server_and_media_id_1) # Attempt to access each piece of media - request, channel = self.make_request( + channel = make_request( + self.reactor, + FakeSite(self.download_resource), "GET", server_and_media_id_2, shorthand=False, access_token=non_admin_user_tok, ) - request.render(self.download_resource) - self.pump(1.0) # Shouldn't be quarantined self.assertEqual( diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index 92c9058887..248c4442c3 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py
@@ -50,20 +50,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): """ Try to get a device of an user without authentication. """ - request, channel = self.make_request("GET", self.url, b"{}") - self.render(request) + channel = self.make_request("GET", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - request, channel = self.make_request("PUT", self.url, b"{}") - self.render(request) + channel = self.make_request("PUT", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - request, channel = self.make_request("DELETE", self.url, b"{}") - self.render(request) + channel = self.make_request("DELETE", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -72,26 +69,23 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): """ If the user is not a server admin, an error is returned. """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url, access_token=self.other_user_token, ) - self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - request, channel = self.make_request( + channel = self.make_request( "PUT", self.url, access_token=self.other_user_token, ) - self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - request, channel = self.make_request( + channel = self.make_request( "DELETE", self.url, access_token=self.other_user_token, ) - self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -105,26 +99,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): % self.other_user_device_id ) - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) - self.render(request) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - request, channel = self.make_request( - "PUT", url, access_token=self.admin_user_tok, - ) - self.render(request) + channel = self.make_request("PUT", url, access_token=self.admin_user_tok,) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - request, channel = self.make_request( - "DELETE", url, access_token=self.admin_user_tok, - ) - self.render(request) + channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -138,26 +123,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): % self.other_user_device_id ) - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) - self.render(request) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) - request, channel = self.make_request( 
- "PUT", url, access_token=self.admin_user_tok, - ) - self.render(request) + channel = self.make_request("PUT", url, access_token=self.admin_user_tok,) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) - request, channel = self.make_request( - "DELETE", url, access_token=self.admin_user_tok, - ) - self.render(request) + channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -170,25 +146,16 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.other_user ) - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) - self.render(request) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - request, channel = self.make_request( - "PUT", url, access_token=self.admin_user_tok, - ) - self.render(request) + channel = self.make_request("PUT", url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) - request, channel = self.make_request( - "DELETE", url, access_token=self.admin_user_tok, - ) - self.render(request) + channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) # Delete unknown device returns status 200 self.assertEqual(200, channel.code, msg=channel.json_body) @@ -212,22 +179,18 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): } body = json.dumps(update) - request, channel = self.make_request( + channel = self.make_request( "PUT", self.url, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"]) # Ensure the display name was not updated. - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) - self.render(request) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new display", channel.json_body["display_name"]) @@ -244,18 +207,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): ) ) - request, channel = self.make_request( - "PUT", self.url, access_token=self.admin_user_tok, - ) - self.render(request) + channel = self.make_request("PUT", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) # Ensure the display name was not updated. 
- request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) - self.render(request) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new display", channel.json_body["display_name"]) @@ -266,21 +223,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): """ # Set new display_name body = json.dumps({"display_name": "new displayname"}) - request, channel = self.make_request( + channel = self.make_request( "PUT", self.url, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) # Check new display_name - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) - self.render(request) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new displayname", channel.json_body["display_name"]) @@ -289,10 +242,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): """ Tests that a normal lookup for a device is successfully """ - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) - self.render(request) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) @@ -313,10 +263,9 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(1, number_devices) # Delete device - request, channel = self.make_request( + channel = self.make_request( "DELETE", self.url, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -346,8 +295,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): """ Try to list devices of an user without authentication. 
""" - request, channel = self.make_request("GET", self.url, b"{}") - self.render(request) + channel = self.make_request("GET", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -358,10 +306,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): """ other_user_token = self.login("user", "pass") - request, channel = self.make_request( - "GET", self.url, access_token=other_user_token, - ) - self.render(request) + channel = self.make_request("GET", self.url, access_token=other_user_token,) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -371,10 +316,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v2/users/@unknown_person:test/devices" - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) - self.render(request) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -385,14 +327,24 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices" - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) - self.render(request) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) + def test_user_has_no_devices(self): + """ + Tests that a normal lookup for devices is successfully + if user has no devices + """ + + # Get devices + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + self.assertEqual(0, len(channel.json_body["devices"])) + def test_get_devices(self): """ Tests that a normal lookup for devices is successfully @@ -403,12 +355,10 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): self.login("user", "pass") # Get devices - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) - self.render(request) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(number_devices, channel.json_body["total"]) self.assertEqual(number_devices, len(channel.json_body["devices"])) self.assertEqual(self.other_user, channel.json_body["devices"][0]["user_id"]) # Check that all fields are available @@ -443,8 +393,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): """ Try to delete devices of an user without authentication. 
""" - request, channel = self.make_request("POST", self.url, b"{}") - self.render(request) + channel = self.make_request("POST", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -455,10 +404,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): """ other_user_token = self.login("user", "pass") - request, channel = self.make_request( - "POST", self.url, access_token=other_user_token, - ) - self.render(request) + channel = self.make_request("POST", self.url, access_token=other_user_token,) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -468,10 +414,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices" - request, channel = self.make_request( - "POST", url, access_token=self.admin_user_tok, - ) - self.render(request) + channel = self.make_request("POST", url, access_token=self.admin_user_tok,) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -482,10 +425,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices" - request, channel = self.make_request( - "POST", url, access_token=self.admin_user_tok, - ) - self.render(request) + channel = self.make_request("POST", url, access_token=self.admin_user_tok,) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -495,13 +435,12 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): Tests that a remove of a device that does not exist returns 200. """ body = json.dumps({"devices": ["unknown_device1", "unknown_device2"]}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) # Delete unknown devices returns status 200 self.assertEqual(200, channel.code, msg=channel.json_body) @@ -527,13 +466,12 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): # Delete devices body = json.dumps({"devices": device_ids}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index bf79086f78..aa389df12f 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py
@@ -70,15 +70,21 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/event_reports" + def test_no_auth(self): + """ + Try to get an event report without authentication. + """ + channel = self.make_request("GET", self.url, b"{}") + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + def test_requester_is_no_admin(self): """ If the user is not a server admin, an error 403 is returned. """ - request, channel = self.make_request( - "GET", self.url, access_token=self.other_user_tok, - ) - self.render(request) + channel = self.make_request("GET", self.url, access_token=self.other_user_tok,) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -88,10 +94,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing list of reported events """ - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) - self.render(request) + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -104,10 +107,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing list of reported events with limit """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=5", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -120,10 +122,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing list of reported events with a defined starting point (from) """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from=5", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -136,10 +137,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing list of reported events with a defined starting point and limit """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -152,12 +152,11 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing list of reported events with a filter of room """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?room_id=%s" % self.room_id1, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 10) @@ -173,12 +172,11 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing list of reported events with a filter of user """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?user_id=%s" % self.other_user, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 10) @@ -194,12 +192,11 @@ class 
EventReportsTestCase(unittest.HomeserverTestCase): Testing list of reported events with a filter of user and room """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?user_id=%s&room_id=%s" % (self.other_user, self.room_id1), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 5) @@ -217,10 +214,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase): """ # fetch the most recent first, largest timestamp - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?dir=b", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -234,10 +230,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase): report += 1 # fetch the oldest first, smallest timestamp - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?dir=f", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -255,10 +250,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing that a invalid search order returns a 400 """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?dir=bar", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) @@ -266,13 +260,12 @@ class EventReportsTestCase(unittest.HomeserverTestCase): def test_limit_is_negative(self): """ - Testing that a negative list parameter returns a 400 + Testing that a negative limit parameter returns a 400 """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) @@ -282,10 +275,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing that a negative from parameter returns a 400 """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from=-5", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) @@ -297,10 +289,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase): # `next_token` does not appear # Number of results is the number of entries - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=20", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -309,10 +300,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase): # `next_token` does not appear # Number of max results is larger than the number of entries - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=21", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), 
msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -321,10 +311,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase): # `next_token` does appear # Number of max results is smaller than the number of entries - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?limit=19", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -334,10 +323,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase): # Check # Set `from` to value of `next_token` for request remaining entries # `next_token` does not appear - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?from=19", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -350,17 +338,16 @@ class EventReportsTestCase(unittest.HomeserverTestCase): resp = self.helper.send(room_id, tok=user_tok) event_id = resp["event_id"] - request, channel = self.make_request( + channel = self.make_request( "POST", "rooms/%s/report/%s" % (room_id, event_id), json.dumps({"score": -100, "reason": "this makes me sad"}), access_token=user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) def _check_fields(self, content): - """Checks that all attributes are present in a event report + """Checks that all attributes are present in an event report """ for c in content: self.assertIn("id", c) @@ -368,15 +355,163 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.assertIn("room_id", c) self.assertIn("event_id", c) self.assertIn("user_id", c) - self.assertIn("reason", c) - self.assertIn("content", c) self.assertIn("sender", c) - self.assertIn("room_alias", c) - self.assertIn("event_json", c) - self.assertIn("score", c["content"]) - self.assertIn("reason", c["content"]) - self.assertIn("auth_events", c["event_json"]) - self.assertIn("type", c["event_json"]) - self.assertIn("room_id", c["event_json"]) - self.assertIn("sender", c["event_json"]) - self.assertIn("content", c["event_json"]) + self.assertIn("canonical_alias", c) + self.assertIn("name", c) + self.assertIn("score", c) + self.assertIn("reason", c) + + +class EventReportDetailTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + report_event.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + self.room_id1 = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok, is_public=True + ) + self.helper.join(self.room_id1, user=self.admin_user, tok=self.admin_user_tok) + + self._create_event_and_report( + room_id=self.room_id1, user_tok=self.other_user_tok, + ) + + # first created event report gets `id`=2 + self.url = "/_synapse/admin/v1/event_reports/2" + + def test_no_auth(self): + """ + Try to get event report without authentication. 
+ """ + channel = self.make_request("GET", self.url, b"{}") + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error 403 is returned. + """ + + channel = self.make_request("GET", self.url, access_token=self.other_user_tok,) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_default_success(self): + """ + Testing get a reported event + """ + + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self._check_fields(channel.json_body) + + def test_invalid_report_id(self): + """ + Testing that an invalid `report_id` returns a 400. + """ + + # `report_id` is negative + channel = self.make_request( + "GET", + "/_synapse/admin/v1/event_reports/-123", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + # `report_id` is a non-numerical string + channel = self.make_request( + "GET", + "/_synapse/admin/v1/event_reports/abcdef", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + # `report_id` is undefined + channel = self.make_request( + "GET", + "/_synapse/admin/v1/event_reports/", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + def test_report_id_not_found(self): + """ + Testing that a not existing `report_id` returns a 404. 
+ """ + + channel = self.make_request( + "GET", + "/_synapse/admin/v1/event_reports/123", + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + self.assertEqual("Event report not found", channel.json_body["error"]) + + def _create_event_and_report(self, room_id, user_tok): + """Create and report events + """ + resp = self.helper.send(room_id, tok=user_tok) + event_id = resp["event_id"] + + channel = self.make_request( + "POST", + "rooms/%s/report/%s" % (room_id, event_id), + json.dumps({"score": -100, "reason": "this makes me sad"}), + access_token=user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + def _check_fields(self, content): + """Checks that all attributes are present in a event report + """ + self.assertIn("id", content) + self.assertIn("received_ts", content) + self.assertIn("room_id", content) + self.assertIn("event_id", content) + self.assertIn("user_id", content) + self.assertIn("sender", content) + self.assertIn("canonical_alias", content) + self.assertIn("name", content) + self.assertIn("event_json", content) + self.assertIn("score", content) + self.assertIn("reason", content) + self.assertIn("auth_events", content["event_json"]) + self.assertIn("type", content["event_json"]) + self.assertIn("room_id", content["event_json"]) + self.assertIn("sender", content["event_json"]) + self.assertIn("content", content["event_json"]) diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py new file mode 100644
index 0000000000..c2b998cdae
--- /dev/null
+++ b/tests/rest/admin/test_media.py
@@ -0,0 +1,535 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from binascii import unhexlify + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client.v1 import login, profile, room +from synapse.rest.media.v1.filepath import MediaFilePaths + +from tests import unittest +from tests.server import FakeSite, make_request + + +class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_media_repo, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.handler = hs.get_device_handler() + self.media_repo = hs.get_media_repository_resource() + self.server_name = hs.hostname + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.filepaths = MediaFilePaths(hs.config.media_store_path) + + def test_no_auth(self): + """ + Try to delete media without authentication. + """ + url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") + + channel = self.make_request("DELETE", url, b"{}") + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. 
+ """ + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") + + url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") + + channel = self.make_request("DELETE", url, access_token=self.other_user_token,) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_media_does_not_exist(self): + """ + Tests that a lookup for a media that does not exist returns a 404 + """ + url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") + + channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_media_is_not_local(self): + """ + Tests that a lookup for a media that is not a local returns a 400 + """ + url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345") + + channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only delete local media", channel.json_body["error"]) + + def test_delete_media(self): + """ + Tests that delete a media is successfully + """ + + download_resource = self.media_repo.children[b"download"] + upload_resource = self.media_repo.children[b"upload"] + image_data = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + + # Upload some media into the room + response = self.helper.upload_media( + upload_resource, image_data, tok=self.admin_user_tok, expect_code=200 + ) + # Extract media ID from the response + server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' + server_name, media_id = server_and_media_id.split("/") + + self.assertEqual(server_name, self.server_name) + + # Attempt to access media + channel = make_request( + self.reactor, + FakeSite(download_resource), + "GET", + server_and_media_id, + shorthand=False, + access_token=self.admin_user_tok, + ) + + # Should be successful + self.assertEqual( + 200, + channel.code, + msg=( + "Expected to receive a 200 on accessing media: %s" % server_and_media_id + ), + ) + + # Test if the file exists + local_path = self.filepaths.local_media_filepath(media_id) + self.assertTrue(os.path.exists(local_path)) + + url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, media_id) + + # Delete media + channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual( + media_id, channel.json_body["deleted_media"][0], + ) + + # Attempt to access media + channel = make_request( + self.reactor, + FakeSite(download_resource), + "GET", + server_and_media_id, + shorthand=False, + access_token=self.admin_user_tok, + ) + self.assertEqual( + 404, + channel.code, + msg=( + "Expected to receive a 404 on accessing deleted media: %s" + % server_and_media_id + ), + ) + + # Test if the file is deleted + self.assertFalse(os.path.exists(local_path)) + + +class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_media_repo, + login.register_servlets, + profile.register_servlets, + room.register_servlets, + ] + + def 
prepare(self, reactor, clock, hs): + self.handler = hs.get_device_handler() + self.media_repo = hs.get_media_repository_resource() + self.server_name = hs.hostname + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.filepaths = MediaFilePaths(hs.config.media_store_path) + self.url = "/_synapse/admin/v1/media/%s/delete" % self.server_name + + def test_no_auth(self): + """ + Try to delete media without authentication. + """ + + channel = self.make_request("POST", self.url, b"{}") + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") + + channel = self.make_request( + "POST", self.url, access_token=self.other_user_token, + ) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_media_is_not_local(self): + """ + Tests that a lookup for media that is not local returns a 400 + """ + url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain" + + channel = self.make_request( + "POST", url + "?before_ts=1234", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only delete local media", channel.json_body["error"]) + + def test_missing_parameter(self): + """ + If the parameter `before_ts` is missing, an error is returned. + """ + channel = self.make_request("POST", self.url, access_token=self.admin_user_tok,) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "Missing integer query parameter b'before_ts'", channel.json_body["error"] + ) + + def test_invalid_parameter(self): + """ + If parameters are invalid, an error is returned. 
+ """ + channel = self.make_request( + "POST", self.url + "?before_ts=-1234", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "Query parameter before_ts must be a string representing a positive integer.", + channel.json_body["error"], + ) + + channel = self.make_request( + "POST", + self.url + "?before_ts=1234&size_gt=-1234", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "Query parameter size_gt must be a string representing a positive integer.", + channel.json_body["error"], + ) + + channel = self.make_request( + "POST", + self.url + "?before_ts=1234&keep_profiles=not_bool", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual( + "Boolean query parameter b'keep_profiles' must be one of ['true', 'false']", + channel.json_body["error"], + ) + + def test_delete_media_never_accessed(self): + """ + Tests that media deleted if it is older than `before_ts` and never accessed + `last_access_ts` is `NULL` and `created_ts` < `before_ts` + """ + + # upload and do not access + server_and_media_id = self._create_media() + self.pump(1.0) + + # test that the file exists + media_id = server_and_media_id.split("/")[1] + local_path = self.filepaths.local_media_filepath(media_id) + self.assertTrue(os.path.exists(local_path)) + + # timestamp after upload/create + now_ms = self.clock.time_msec() + channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms), + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual( + media_id, channel.json_body["deleted_media"][0], + ) + + self._access_media(server_and_media_id, False) + + def test_keep_media_by_date(self): + """ + Tests that media is not deleted if it is newer than `before_ts` + """ + + # timestamp before upload + now_ms = self.clock.time_msec() + server_and_media_id = self._create_media() + + self._access_media(server_and_media_id) + + channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms), + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + + self._access_media(server_and_media_id) + + # timestamp after upload + now_ms = self.clock.time_msec() + channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms), + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual( + server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0], + ) + + self._access_media(server_and_media_id, False) + + def test_keep_media_by_size(self): + """ + Tests that media is not deleted if its size is smaller than or equal + to `size_gt` + """ + server_and_media_id = self._create_media() + + self._access_media(server_and_media_id) + + now_ms = self.clock.time_msec() + channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms) + "&size_gt=67", + access_token=self.admin_user_tok, + ) + 
self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + + self._access_media(server_and_media_id) + + now_ms = self.clock.time_msec() + channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms) + "&size_gt=66", + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual( + server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0], + ) + + self._access_media(server_and_media_id, False) + + def test_keep_media_by_user_avatar(self): + """ + Tests that we do not delete media if is used as a user avatar + Tests parameter `keep_profiles` + """ + server_and_media_id = self._create_media() + + self._access_media(server_and_media_id) + + # set media as avatar + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.admin_user,), + content=json.dumps({"avatar_url": "mxc://%s" % (server_and_media_id,)}), + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + now_ms = self.clock.time_msec() + channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true", + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + + self._access_media(server_and_media_id) + + now_ms = self.clock.time_msec() + channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false", + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual( + server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0], + ) + + self._access_media(server_and_media_id, False) + + def test_keep_media_by_room_avatar(self): + """ + Tests that we do not delete media if it is used as a room avatar + Tests parameter `keep_profiles` + """ + server_and_media_id = self._create_media() + + self._access_media(server_and_media_id) + + # set media as room avatar + room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + channel = self.make_request( + "PUT", + "/rooms/%s/state/m.room.avatar" % (room_id,), + content=json.dumps({"url": "mxc://%s" % (server_and_media_id,)}), + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + now_ms = self.clock.time_msec() + channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true", + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + + self._access_media(server_and_media_id) + + now_ms = self.clock.time_msec() + channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false", + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual( + server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0], + ) + + self._access_media(server_and_media_id, False) + + def _create_media(self): + """ + Create a media and return media_id and server_and_media_id + """ + upload_resource = self.media_repo.children[b"upload"] + # file size is 67 Byte + image_data = unhexlify( + 
b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + + # Upload some media into the room + response = self.helper.upload_media( + upload_resource, image_data, tok=self.admin_user_tok, expect_code=200 + ) + # Extract media ID from the response + server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' + server_name = server_and_media_id.split("/")[0] + + # Check that new media is a local and not remote + self.assertEqual(server_name, self.server_name) + + return server_and_media_id + + def _access_media(self, server_and_media_id, expect_success=True): + """ + Try to access a media and check the result + """ + download_resource = self.media_repo.children[b"download"] + + media_id = server_and_media_id.split("/")[1] + local_path = self.filepaths.local_media_filepath(media_id) + + channel = make_request( + self.reactor, + FakeSite(download_resource), + "GET", + server_and_media_id, + shorthand=False, + access_token=self.admin_user_tok, + ) + + if expect_success: + self.assertEqual( + 200, + channel.code, + msg=( + "Expected to receive a 200 on accessing media: %s" + % server_and_media_id + ), + ) + # Test that the file exists + self.assertTrue(os.path.exists(local_path)) + else: + self.assertEqual( + 404, + channel.code, + msg=( + "Expected to receive a 404 on accessing deleted media: %s" + % (server_and_media_id) + ), + ) + # Test that the file is deleted + self.assertFalse(os.path.exists(local_path)) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 6dfc709dc5..fa620f97f3 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -20,6 +20,7 @@ from typing import List, Optional from mock import Mock import synapse.rest.admin +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import Codes from synapse.rest.client.v1 import directory, events, login, room @@ -78,14 +79,13 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): ) # Test that the admin can still send shutdown - url = "admin/shutdown_room/" + room_id - request, channel = self.make_request( + url = "/_synapse/admin/v1/shutdown_room/" + room_id + channel = self.make_request( "POST", url.encode("ascii"), json.dumps({"new_room_user_id": self.admin_user}), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -104,24 +104,22 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): # Enable world readable url = "rooms/%s/state/m.room.history_visibility" % (room_id,) - request, channel = self.make_request( + channel = self.make_request( "PUT", url.encode("ascii"), json.dumps({"history_visibility": "world_readable"}), access_token=self.other_user_token, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Test that the admin can still send shutdown - url = "admin/shutdown_room/" + room_id - request, channel = self.make_request( + url = "/_synapse/admin/v1/shutdown_room/" + room_id + channel = self.make_request( "POST", url.encode("ascii"), json.dumps({"new_room_user_id": self.admin_user}), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -133,19 +131,17 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): """ url = "rooms/%s/initialSync" % (room_id,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) - self.render(request) self.assertEqual( expect_code, int(channel.result["code"]), msg=channel.result["body"] ) url = "events?timeout=0&room_id=" + room_id - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) - self.render(request) self.assertEqual( expect_code, int(channel.result["code"]), msg=channel.result["body"] ) @@ -189,10 +185,9 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): If the user is not a server admin, an error 403 is returned. 
""" - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, json.dumps({}), access_token=self.other_user_tok, ) - self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -203,10 +198,9 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v1/rooms/!unknown:test/delete" - request, channel = self.make_request( + channel = self.make_request( "POST", url, json.dumps({}), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -217,10 +211,9 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v1/rooms/invalidroom/delete" - request, channel = self.make_request( + channel = self.make_request( "POST", url, json.dumps({}), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( @@ -233,13 +226,12 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ body = json.dumps({"new_room_user_id": "@unknown:test"}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertIn("new_room_id", channel.json_body) @@ -253,13 +245,12 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ body = json.dumps({"new_room_user_id": "@not:exist.bla"}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( @@ -272,13 +263,12 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ body = json.dumps({"block": "NotBool"}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) @@ -289,13 +279,12 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ body = json.dumps({"purge": "NotBool"}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) @@ -316,13 +305,12 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): body = json.dumps({"block": True, "purge": True}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url.encode("ascii"), content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(None, channel.json_body["new_room_id"]) @@ -350,13 +338,12 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): body = json.dumps({"block": False, "purge": True}) - request, channel = self.make_request( + channel 
= self.make_request( "POST", self.url.encode("ascii"), content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(None, channel.json_body["new_room_id"]) @@ -385,13 +372,12 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): body = json.dumps({"block": False, "purge": False}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url.encode("ascii"), content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(None, channel.json_body["new_room_id"]) @@ -433,13 +419,12 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): # Test that the admin can still send shutdown url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id - request, channel = self.make_request( + channel = self.make_request( "POST", url.encode("ascii"), json.dumps({"new_room_user_id": self.admin_user}), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) @@ -464,13 +449,12 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): # Enable world readable url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,) - request, channel = self.make_request( + channel = self.make_request( "PUT", url.encode("ascii"), json.dumps({"history_visibility": "world_readable"}), access_token=self.other_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Test that room is not purged @@ -482,13 +466,12 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): # Test that the admin can still send shutdown url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id - request, channel = self.make_request( + channel = self.make_request( "POST", url.encode("ascii"), json.dumps({"new_room_user_id": self.admin_user}), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) @@ -531,40 +514,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): def _is_purged(self, room_id): """Test that the following tables have been purged of all rows related to the room. """ - for table in ( - "current_state_events", - "event_backward_extremities", - "event_forward_extremities", - "event_json", - "event_push_actions", - "event_search", - "events", - "group_rooms", - "public_room_list_stream", - "receipts_graph", - "receipts_linearized", - "room_aliases", - "room_depth", - "room_memberships", - "room_stats_state", - "room_stats_current", - "room_stats_historical", - "room_stats_earliest_token", - "rooms", - "stream_ordering_to_exterm", - "users_in_public_rooms", - "users_who_share_private_rooms", - "appservice_room_list", - "e2e_room_keys", - "event_push_summary", - "pusher_throttle", - "group_summary_rooms", - "local_invites", - "room_account_data", - "room_tags", - # "state_groups", # Current impl leaves orphaned state groups around. 
- "state_groups_state", - ): + for table in PURGE_TABLES: count = self.get_success( self.store.db_pool.simple_select_one_onecol( table=table, @@ -581,19 +531,17 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ url = "rooms/%s/initialSync" % (room_id,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) - self.render(request) self.assertEqual( expect_code, int(channel.result["code"]), msg=channel.result["body"] ) url = "events?timeout=0&room_id=" + room_id - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) - self.render(request) self.assertEqual( expect_code, int(channel.result["code"]), msg=channel.result["body"] ) @@ -622,50 +570,17 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase): self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok) url = "/_synapse/admin/v1/purge_room" - request, channel = self.make_request( + channel = self.make_request( "POST", url.encode("ascii"), {"room_id": room_id}, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Test that the following tables have been purged of all rows related to the room. - for table in ( - "current_state_events", - "event_backward_extremities", - "event_forward_extremities", - "event_json", - "event_push_actions", - "event_search", - "events", - "group_rooms", - "public_room_list_stream", - "receipts_graph", - "receipts_linearized", - "room_aliases", - "room_depth", - "room_memberships", - "room_stats_state", - "room_stats_current", - "room_stats_historical", - "room_stats_earliest_token", - "rooms", - "stream_ordering_to_exterm", - "users_in_public_rooms", - "users_who_share_private_rooms", - "appservice_room_list", - "e2e_room_keys", - "event_push_summary", - "pusher_throttle", - "group_summary_rooms", - "room_account_data", - "room_tags", - # "state_groups", # Current impl leaves orphaned state groups around. 
- "state_groups_state", - ): + for table in PURGE_TABLES: count = self.get_success( self.store.db_pool.simple_select_one_onecol( table=table, @@ -709,10 +624,9 @@ class RoomTestCase(unittest.HomeserverTestCase): # Request the list of rooms url = "/_synapse/admin/v1/rooms" - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) - self.render(request) # Check request completed successfully self.assertEqual(200, int(channel.code), msg=channel.json_body) @@ -791,10 +705,9 @@ class RoomTestCase(unittest.HomeserverTestCase): limit, "name", ) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual( 200, int(channel.result["code"]), msg=channel.result["body"] ) @@ -832,10 +745,9 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(room_ids, returned_room_ids) url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) def test_correct_room_attributes(self): @@ -853,13 +765,12 @@ class RoomTestCase(unittest.HomeserverTestCase): # Create a new alias to this room url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),) - request, channel = self.make_request( + channel = self.make_request( "PUT", url.encode("ascii"), {"room_id": room_id}, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Set this new alias as the canonical alias for this room @@ -884,10 +795,9 @@ class RoomTestCase(unittest.HomeserverTestCase): # Request the list of rooms url = "/_synapse/admin/v1/rooms" - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Check that rooms were returned @@ -926,13 +836,12 @@ class RoomTestCase(unittest.HomeserverTestCase): url = "/_matrix/client/r0/directory/room/%s" % ( urllib.parse.quote(test_alias), ) - request, channel = self.make_request( + channel = self.make_request( "PUT", url.encode("ascii"), {"room_id": room_id}, access_token=admin_user_tok, ) - self.render(request) self.assertEqual( 200, int(channel.result["code"]), msg=channel.result["body"] ) @@ -967,10 +876,9 @@ class RoomTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,) if reverse: url += "&dir=b" - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) # Check that rooms were returned @@ -1104,10 +1012,9 @@ class RoomTestCase(unittest.HomeserverTestCase): expected_http_code: The expected http code for the request """ url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) if expected_http_code != 200: @@ -1144,6 +1051,13 @@ class 
RoomTestCase(unittest.HomeserverTestCase): _search_test(room_id_2, "else") _search_test(room_id_2, "se") + # Test case insensitive + _search_test(room_id_1, "SOMETHING") + _search_test(room_id_1, "THING") + + _search_test(room_id_2, "ELSE") + _search_test(room_id_2, "SE") + _search_test(None, "foo") _search_test(None, "bar") _search_test(None, "", expected_http_code=400) @@ -1166,10 +1080,9 @@ class RoomTestCase(unittest.HomeserverTestCase): ) url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("room_id", channel.json_body) @@ -1179,6 +1092,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertIn("canonical_alias", channel.json_body) self.assertIn("joined_members", channel.json_body) self.assertIn("joined_local_members", channel.json_body) + self.assertIn("joined_local_devices", channel.json_body) self.assertIn("version", channel.json_body) self.assertIn("creator", channel.json_body) self.assertIn("encryption", channel.json_body) @@ -1191,6 +1105,39 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(room_id_1, channel.json_body["room_id"]) + def test_single_room_devices(self): + """Test that `joined_local_devices` can be requested correctly""" + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) + channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["joined_local_devices"]) + + # Have another user join the room + user_1 = self.register_user("foo", "pass") + user_tok_1 = self.login("foo", "pass") + self.helper.join(room_id_1, user_1, tok=user_tok_1) + + url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) + channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(2, channel.json_body["joined_local_devices"]) + + # leave room + self.helper.leave(room_id_1, self.admin_user, tok=self.admin_user_tok) + self.helper.leave(room_id_1, user_1, tok=user_tok_1) + url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) + channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["joined_local_devices"]) + def test_room_members(self): """Test that room members can be requested correctly""" # Create two test rooms @@ -1214,10 +1161,9 @@ class RoomTestCase(unittest.HomeserverTestCase): self.helper.join(room_id_2, user_3, tok=user_tok_3) url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertCountEqual( @@ -1226,10 +1172,9 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["total"], 3) url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,) - request, channel = self.make_request( + channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok, ) - self.render(request) 
self.assertEqual(200, channel.code, msg=channel.json_body) self.assertCountEqual( @@ -1267,13 +1212,12 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): """ body = json.dumps({"user_id": self.second_user_id}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, content=body.encode(encoding="utf_8"), access_token=self.second_tok, ) - self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -1284,13 +1228,12 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): """ body = json.dumps({"unknown_parameter": "@unknown:test"}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) @@ -1301,13 +1244,12 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): """ body = json.dumps({"user_id": "@unknown:test"}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -1318,13 +1260,12 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): """ body = json.dumps({"user_id": "@not:exist.bla"}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( @@ -1339,13 +1280,12 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): body = json.dumps({"user_id": self.second_user_id}) url = "/_synapse/admin/v1/join/!unknown:test" - request, channel = self.make_request( + channel = self.make_request( "POST", url, content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("No known servers", channel.json_body["error"]) @@ -1357,13 +1297,12 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): body = json.dumps({"user_id": self.second_user_id}) url = "/_synapse/admin/v1/join/invalidroom" - request, channel = self.make_request( + channel = self.make_request( "POST", url, content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( @@ -1377,23 +1316,21 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): """ body = json.dumps({"user_id": self.second_user_id}) - request, channel = self.make_request( + channel = self.make_request( "POST", self.url, content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(self.public_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.render(request) 
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) @@ -1408,13 +1345,12 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/join/{}".format(private_room_id) body = json.dumps({"user_id": self.second_user_id}) - request, channel = self.make_request( + channel = self.make_request( "POST", url, content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -1439,10 +1375,9 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): # Validate if server admin is a member of the room - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, ) - self.render(request) self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) @@ -1451,22 +1386,20 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/join/{}".format(private_room_id) body = json.dumps({"user_id": self.second_user_id}) - request, channel = self.make_request( + channel = self.make_request( "POST", url, content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(private_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.render(request) self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) @@ -1481,22 +1414,192 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/join/{}".format(private_room_id) body = json.dumps({"user_id": self.second_user_id}) - request, channel = self.make_request( + channel = self.make_request( "POST", url, content=body.encode(encoding="utf_8"), access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(private_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.render(request) self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) + + +class MakeRoomAdminTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.creator = self.register_user("creator", "test") + self.creator_tok = self.login("creator", "test") + + self.second_user_id = self.register_user("second", "test") + self.second_tok = self.login("second", "test") + + self.public_room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, 
is_public=True + ) + self.url = "/_synapse/admin/v1/rooms/{}/make_room_admin".format( + self.public_room_id + ) + + def test_public_room(self): + """Test that getting admin in a public room works. + """ + room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=True + ) + + channel = self.make_request( + "POST", + "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + content={}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Now we test that we can join the room and ban a user. + self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok) + self.helper.change_membership( + room_id, + self.admin_user, + "@test:test", + Membership.BAN, + tok=self.admin_user_tok, + ) + + def test_private_room(self): + """Test that getting admin in a private room works and we get invited. + """ + room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=False, + ) + + channel = self.make_request( + "POST", + "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + content={}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Now we test that we can join the room (we should have received an + # invite) and can ban a user. + self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok) + self.helper.change_membership( + room_id, + self.admin_user, + "@test:test", + Membership.BAN, + tok=self.admin_user_tok, + ) + + def test_other_user(self): + """Test that giving admin in a public room works to a non-admin user works. + """ + room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=True + ) + + channel = self.make_request( + "POST", + "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + content={"user_id": self.second_user_id}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Now we test that we can join the room and ban a user. + self.helper.join(room_id, self.second_user_id, tok=self.second_tok) + self.helper.change_membership( + room_id, + self.second_user_id, + "@test:test", + Membership.BAN, + tok=self.second_tok, + ) + + def test_not_enough_power(self): + """Test that we get a sensible error if there are no local room admins. + """ + room_id = self.helper.create_room_as( + self.creator, tok=self.creator_tok, is_public=True + ) + + # The creator drops admin rights in the room. + pl = self.helper.get_state( + room_id, EventTypes.PowerLevels, tok=self.creator_tok + ) + pl["users"][self.creator] = 0 + self.helper.send_state( + room_id, EventTypes.PowerLevels, body=pl, tok=self.creator_tok + ) + + channel = self.make_request( + "POST", + "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + content={}, + access_token=self.admin_user_tok, + ) + + # We expect this to fail with a 400 as there are no room admins. 
+ # + # (Note we assert the error message to ensure that it's not denied for + # some other reason) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + channel.json_body["error"], + "No local admin user in room with power to update power levels.", + ) + + +PURGE_TABLES = [ + "current_state_events", + "event_backward_extremities", + "event_forward_extremities", + "event_json", + "event_push_actions", + "event_search", + "events", + "group_rooms", + "public_room_list_stream", + "receipts_graph", + "receipts_linearized", + "room_aliases", + "room_depth", + "room_memberships", + "room_stats_state", + "room_stats_current", + "room_stats_historical", + "room_stats_earliest_token", + "rooms", + "stream_ordering_to_exterm", + "users_in_public_rooms", + "users_who_share_private_rooms", + "appservice_room_list", + "e2e_room_keys", + "event_push_summary", + "pusher_throttle", + "group_summary_rooms", + "room_account_data", + "room_tags", + # "state_groups", # Current impl leaves orphaned state groups around. + "state_groups_state", +] diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py new file mode 100644
index 0000000000..73f8a8ec99
--- /dev/null
+++ b/tests/rest/admin/test_statistics.py
@@ -0,0 +1,452 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from binascii import unhexlify +from typing import Any, Dict, List, Optional + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client.v1 import login + +from tests import unittest + + +class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.media_repo = hs.get_media_repository_resource() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + self.url = "/_synapse/admin/v1/statistics/users/media" + + def test_no_auth(self): + """ + Try to list users without authentication. + """ + channel = self.make_request("GET", self.url, b"{}") + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error 403 is returned. + """ + channel = self.make_request( + "GET", self.url, json.dumps({}), access_token=self.other_user_tok, + ) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_invalid_parameter(self): + """ + If parameters are invalid, an error is returned. 
+ """ + # unkown order_by + channel = self.make_request( + "GET", self.url + "?order_by=bar", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # negative from + channel = self.make_request( + "GET", self.url + "?from=-5", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # negative limit + channel = self.make_request( + "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # negative from_ts + channel = self.make_request( + "GET", self.url + "?from_ts=-1234", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # negative until_ts + channel = self.make_request( + "GET", self.url + "?until_ts=-1234", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # until_ts smaller from_ts + channel = self.make_request( + "GET", + self.url + "?from_ts=10&until_ts=5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # empty search term + channel = self.make_request( + "GET", self.url + "?search_term=", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # invalid search order + channel = self.make_request( + "GET", self.url + "?dir=bar", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + def test_limit(self): + """ + Testing list of media with limit + """ + self._create_users_with_media(10, 2) + + channel = self.make_request( + "GET", self.url + "?limit=5", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 10) + self.assertEqual(len(channel.json_body["users"]), 5) + self.assertEqual(channel.json_body["next_token"], 5) + self._check_fields(channel.json_body["users"]) + + def test_from(self): + """ + Testing list of media with a defined starting point (from) + """ + self._create_users_with_media(20, 2) + + channel = self.make_request( + "GET", self.url + "?from=5", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(len(channel.json_body["users"]), 15) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["users"]) + + def test_limit_and_from(self): + """ + Testing list of media with a defined starting point and limit + """ + self._create_users_with_media(20, 2) + + channel = self.make_request( + "GET", self.url + "?from=5&limit=10", 
access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(channel.json_body["next_token"], 15) + self.assertEqual(len(channel.json_body["users"]), 10) + self._check_fields(channel.json_body["users"]) + + def test_next_token(self): + """ + Testing that `next_token` appears at the right place + """ + + number_users = 20 + self._create_users_with_media(number_users, 3) + + # `next_token` does not appear + # Number of results is the number of entries + channel = self.make_request( + "GET", self.url + "?limit=20", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), number_users) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does not appear + # Number of max results is larger than the number of entries + channel = self.make_request( + "GET", self.url + "?limit=21", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), number_users) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does appear + # Number of max results is smaller than the number of entries + channel = self.make_request( + "GET", self.url + "?limit=19", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), 19) + self.assertEqual(channel.json_body["next_token"], 19) + + # Set `from` to value of `next_token` for request remaining entries + # Check `next_token` does not appear + channel = self.make_request( + "GET", self.url + "?from=19", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), 1) + self.assertNotIn("next_token", channel.json_body) + + def test_no_media(self): + """ + Tests that a normal lookup for statistics is successfully + if users have no media created + """ + + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + self.assertEqual(0, len(channel.json_body["users"])) + + def test_order_by(self): + """ + Testing order list with parameter `order_by` + """ + + # create users + self.register_user("user_a", "pass", displayname="UserZ") + userA_tok = self.login("user_a", "pass") + self._create_media(userA_tok, 1) + + self.register_user("user_b", "pass", displayname="UserY") + userB_tok = self.login("user_b", "pass") + self._create_media(userB_tok, 3) + + self.register_user("user_c", "pass", displayname="UserX") + userC_tok = self.login("user_c", "pass") + self._create_media(userC_tok, 2) + + # order by user_id + self._order_test("user_id", ["@user_a:test", "@user_b:test", "@user_c:test"]) + self._order_test( + "user_id", ["@user_a:test", "@user_b:test", "@user_c:test"], "f", + ) + self._order_test( + "user_id", ["@user_c:test", "@user_b:test", "@user_a:test"], "b", + ) + + # order by displayname + 
self._order_test( + "displayname", ["@user_c:test", "@user_b:test", "@user_a:test"] + ) + self._order_test( + "displayname", ["@user_c:test", "@user_b:test", "@user_a:test"], "f", + ) + self._order_test( + "displayname", ["@user_a:test", "@user_b:test", "@user_c:test"], "b", + ) + + # order by media_length + self._order_test( + "media_length", ["@user_a:test", "@user_c:test", "@user_b:test"], + ) + self._order_test( + "media_length", ["@user_a:test", "@user_c:test", "@user_b:test"], "f", + ) + self._order_test( + "media_length", ["@user_b:test", "@user_c:test", "@user_a:test"], "b", + ) + + # order by media_count + self._order_test( + "media_count", ["@user_a:test", "@user_c:test", "@user_b:test"], + ) + self._order_test( + "media_count", ["@user_a:test", "@user_c:test", "@user_b:test"], "f", + ) + self._order_test( + "media_count", ["@user_b:test", "@user_c:test", "@user_a:test"], "b", + ) + + def test_from_until_ts(self): + """ + Testing filter by time with parameters `from_ts` and `until_ts` + """ + # create media earlier than `ts1` to ensure that `from_ts` is working + self._create_media(self.other_user_tok, 3) + self.pump(1) + ts1 = self.clock.time_msec() + + # list all media when filter is not set + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["users"][0]["media_count"], 3) + + # filter media starting at `ts1` after creating first media + # result is 0 + channel = self.make_request( + "GET", self.url + "?from_ts=%s" % (ts1,), access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 0) + + self._create_media(self.other_user_tok, 3) + self.pump(1) + ts2 = self.clock.time_msec() + # create media after `ts2` to ensure that `until_ts` is working + self._create_media(self.other_user_tok, 3) + + # filter media between `ts1` and `ts2` + channel = self.make_request( + "GET", + self.url + "?from_ts=%s&until_ts=%s" % (ts1, ts2), + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["users"][0]["media_count"], 3) + + # filter media until `ts2` and earlier + channel = self.make_request( + "GET", self.url + "?until_ts=%s" % (ts2,), access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["users"][0]["media_count"], 6) + + def test_search_term(self): + self._create_users_with_media(20, 1) + + # check without filter get all users + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + + # filter user 1 and 10-19 by `user_id` + channel = self.make_request( + "GET", + self.url + "?search_term=foo_user_1", + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 11) + + # filter on this user in `displayname` + channel = self.make_request( + "GET", + self.url + "?search_term=bar_user_10", + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["users"][0]["displayname"], 
"bar_user_10") + self.assertEqual(channel.json_body["total"], 1) + + # filter and get empty result + channel = self.make_request( + "GET", self.url + "?search_term=foobar", access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 0) + + def _create_users_with_media(self, number_users: int, media_per_user: int): + """ + Create a number of users with a number of media + Args: + number_users: Number of users to be created + media_per_user: Number of media to be created for each user + """ + for i in range(number_users): + self.register_user("foo_user_%s" % i, "pass", displayname="bar_user_%s" % i) + user_tok = self.login("foo_user_%s" % i, "pass") + self._create_media(user_tok, media_per_user) + + def _create_media(self, user_token: str, number_media: int): + """ + Create a number of media for a specific user + Args: + user_token: Access token of the user + number_media: Number of media to be created for the user + """ + upload_resource = self.media_repo.children[b"upload"] + for i in range(number_media): + # file size is 67 Byte + image_data = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + + # Upload some media into the room + self.helper.upload_media( + upload_resource, image_data, tok=user_token, expect_code=200 + ) + + def _check_fields(self, content: List[Dict[str, Any]]): + """Checks that all attributes are present in content + Args: + content: List that is checked for content + """ + for c in content: + self.assertIn("user_id", c) + self.assertIn("displayname", c) + self.assertIn("media_count", c) + self.assertIn("media_length", c) + + def _order_test( + self, order_type: str, expected_user_list: List[str], dir: Optional[str] = None + ): + """Request the list of users in a certain order. Assert that order is what + we expect + Args: + order_type: The type of ordering to give the server + expected_user_list: The list of user_ids in the order we expect to get + back from the server + dir: The direction of ordering to give the server + """ + + url = self.url + "?order_by=%s" % (order_type,) + if dir is not None and dir in ("b", "f"): + url += "&dir=%s" % (dir,) + channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], len(expected_user_list)) + + returned_order = [row["user_id"] for row in channel.json_body["users"]] + self.assertListEqual(expected_user_list, returned_order) + self._check_fields(channel.json_body["users"]) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 98d0623734..9b2e4765f6 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -17,14 +17,16 @@ import hashlib import hmac import json import urllib.parse +from binascii import unhexlify +from typing import Optional from mock import Mock import synapse.rest.admin from synapse.api.constants import UserTypes from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client.v1 import login, logout, profile, room +from synapse.rest.client.v2_alpha import devices, sync from tests import unittest from tests.test_utils import make_awaitable @@ -33,11 +35,14 @@ from tests.unittest import override_config class UserRegisterTestCase(unittest.HomeserverTestCase): - servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + profile.register_servlets, + ] def make_homeserver(self, reactor, clock): - self.url = "/_matrix/client/r0/admin/register" + self.url = "/_synapse/admin/v1/register" self.registration_handler = Mock() self.identity_handler = Mock() @@ -66,8 +71,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): """ self.hs.config.registration_shared_secret = None - request, channel = self.make_request("POST", self.url, b"{}") - self.render(request) + channel = self.make_request("POST", self.url, b"{}") self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( @@ -84,8 +88,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.hs.get_secrets = Mock(return_value=secrets) - request, channel = self.make_request("GET", self.url) - self.render(request) + channel = self.make_request("GET", self.url) self.assertEqual(channel.json_body, {"nonce": "abcd"}) @@ -94,16 +97,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): Calling GET on the endpoint will return a randomised nonce, which will only last for SALT_TIMEOUT (60s). """ - request, channel = self.make_request("GET", self.url) - self.render(request) + channel = self.make_request("GET", self.url) nonce = channel.json_body["nonce"] # 59 seconds self.reactor.advance(59) body = json.dumps({"nonce": nonce}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("username must be specified", channel.json_body["error"]) @@ -111,8 +112,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # 61 seconds self.reactor.advance(2) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("unrecognised nonce", channel.json_body["error"]) @@ -121,8 +121,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): """ Only the provided nonce can be used, as it's checked in the MAC. 
""" - request, channel = self.make_request("GET", self.url) - self.render(request) + channel = self.make_request("GET", self.url) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -138,8 +137,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "mac": want_mac, } ) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("HMAC incorrect", channel.json_body["error"]) @@ -149,8 +147,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): When the correct nonce is provided, and the right key is provided, the user is registered. """ - request, channel = self.make_request("GET", self.url) - self.render(request) + channel = self.make_request("GET", self.url) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -169,8 +166,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "mac": want_mac, } ) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["user_id"]) @@ -179,8 +175,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): """ A valid unrecognised nonce. """ - request, channel = self.make_request("GET", self.url) - self.render(request) + channel = self.make_request("GET", self.url) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -196,15 +191,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "mac": want_mac, } ) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["user_id"]) # Now, try and reuse it - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("unrecognised nonce", channel.json_body["error"]) @@ -217,8 +210,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): """ def nonce(): - request, channel = self.make_request("GET", self.url) - self.render(request) + channel = self.make_request("GET", self.url) return channel.json_body["nonce"] # @@ -227,8 +219,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be present body = json.dumps({}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("nonce must be specified", channel.json_body["error"]) @@ -239,32 +230,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be present body = json.dumps({"nonce": nonce()}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, 
int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("username must be specified", channel.json_body["error"]) # Must be a string body = json.dumps({"nonce": nonce(), "username": 1234}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": "a" * 1000}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid username", channel.json_body["error"]) @@ -275,32 +262,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be present body = json.dumps({"nonce": nonce(), "username": "a"}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("password must be specified", channel.json_body["error"]) # Must be a string body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid password", channel.json_body["error"]) # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid password", channel.json_body["error"]) # Super long body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000}) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid password", channel.json_body["error"]) @@ -318,12 +301,113 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "user_type": "invalid", } ) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid user type", channel.json_body["error"]) + def test_displayname(self): + """ + Test that displayname of new user is set + """ + + # 
set no displayname + channel = self.make_request("GET", self.url) + nonce = channel.json_body["nonce"] + + want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) + want_mac.update(nonce.encode("ascii") + b"\x00bob1\x00abc123\x00notadmin") + want_mac = want_mac.hexdigest() + + body = json.dumps( + {"nonce": nonce, "username": "bob1", "password": "abc123", "mac": want_mac} + ) + channel = self.make_request("POST", self.url, body.encode("utf8")) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob1:test", channel.json_body["user_id"]) + + channel = self.make_request("GET", "/profile/@bob1:test/displayname") + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("bob1", channel.json_body["displayname"]) + + # displayname is None + channel = self.make_request("GET", self.url) + nonce = channel.json_body["nonce"] + + want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) + want_mac.update(nonce.encode("ascii") + b"\x00bob2\x00abc123\x00notadmin") + want_mac = want_mac.hexdigest() + + body = json.dumps( + { + "nonce": nonce, + "username": "bob2", + "displayname": None, + "password": "abc123", + "mac": want_mac, + } + ) + channel = self.make_request("POST", self.url, body.encode("utf8")) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob2:test", channel.json_body["user_id"]) + + channel = self.make_request("GET", "/profile/@bob2:test/displayname") + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("bob2", channel.json_body["displayname"]) + + # displayname is empty + channel = self.make_request("GET", self.url) + nonce = channel.json_body["nonce"] + + want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) + want_mac.update(nonce.encode("ascii") + b"\x00bob3\x00abc123\x00notadmin") + want_mac = want_mac.hexdigest() + + body = json.dumps( + { + "nonce": nonce, + "username": "bob3", + "displayname": "", + "password": "abc123", + "mac": want_mac, + } + ) + channel = self.make_request("POST", self.url, body.encode("utf8")) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob3:test", channel.json_body["user_id"]) + + channel = self.make_request("GET", "/profile/@bob3:test/displayname") + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + + # set displayname + channel = self.make_request("GET", self.url) + nonce = channel.json_body["nonce"] + + want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) + want_mac.update(nonce.encode("ascii") + b"\x00bob4\x00abc123\x00notadmin") + want_mac = want_mac.hexdigest() + + body = json.dumps( + { + "nonce": nonce, + "username": "bob4", + "displayname": "Bob's Name", + "password": "abc123", + "mac": want_mac, + } + ) + channel = self.make_request("POST", self.url, body.encode("utf8")) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob4:test", channel.json_body["user_id"]) + + channel = self.make_request("GET", "/profile/@bob4:test/displayname") + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("Bob's Name", channel.json_body["displayname"]) + @override_config( {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0} ) @@ -346,8 +430,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) # Register new user with admin API - request, channel 
= self.make_request("GET", self.url) - self.render(request) + channel = self.make_request("GET", self.url) nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) @@ -366,8 +449,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "mac": want_mac, } ) - request, channel = self.make_request("POST", self.url, body.encode("utf8")) - self.render(request) + channel = self.make_request("POST", self.url, body.encode("utf8")) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["user_id"]) @@ -385,35 +467,125 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - self.register_user("user1", "pass1", admin=False) - self.register_user("user2", "pass2", admin=False) + self.user1 = self.register_user( + "user1", "pass1", admin=False, displayname="Name 1" + ) + self.user2 = self.register_user( + "user2", "pass2", admin=False, displayname="Name 2" + ) def test_no_auth(self): """ Try to list users without authentication. """ - request, channel = self.make_request("GET", self.url, b"{}") - self.render(request) + channel = self.make_request("GET", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("M_MISSING_TOKEN", channel.json_body["errcode"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + other_user_token = self.login("user1", "pass1") + + channel = self.make_request("GET", self.url, access_token=other_user_token) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_all_users(self): """ List all users, including deactivated users. """ - request, channel = self.make_request( + channel = self.make_request( "GET", self.url + "?deactivated=true", b"{}", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(3, len(channel.json_body["users"])) self.assertEqual(3, channel.json_body["total"]) + # Check that all fields are available + for u in channel.json_body["users"]: + self.assertIn("name", u) + self.assertIn("is_guest", u) + self.assertIn("admin", u) + self.assertIn("user_type", u) + self.assertIn("deactivated", u) + self.assertIn("displayname", u) + self.assertIn("avatar_url", u) + + def test_search_term(self): + """Test that searching for a users works correctly""" + + def _search_test( + expected_user_id: Optional[str], + search_term: str, + search_field: Optional[str] = "name", + expected_http_code: Optional[int] = 200, + ): + """Search for a user and check that the returned user's id is a match + + Args: + expected_user_id: The user_id expected to be returned by the API. 
Set + to None to expect zero results for the search + search_term: The term to search for user names with + search_field: Field which is to request: `name` or `user_id` + expected_http_code: The expected http code for the request + """ + url = self.url + "?%s=%s" % (search_field, search_term,) + channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) + + if expected_http_code != 200: + return + + # Check that users were returned + self.assertTrue("users" in channel.json_body) + users = channel.json_body["users"] + + # Check that the expected number of users were returned + expected_user_count = 1 if expected_user_id else 0 + self.assertEqual(len(users), expected_user_count) + self.assertEqual(channel.json_body["total"], expected_user_count) + + if expected_user_id: + # Check that the first returned user id is correct + u = users[0] + self.assertEqual(expected_user_id, u["name"]) + + # Perform search tests + _search_test(self.user1, "er1") + _search_test(self.user1, "me 1") + + _search_test(self.user2, "er2") + _search_test(self.user2, "me 2") + + _search_test(self.user1, "er1", "user_id") + _search_test(self.user2, "er2", "user_id") + + # Test case insensitive + _search_test(self.user1, "ER1") + _search_test(self.user1, "NAME 1") + + _search_test(self.user2, "ER2") + _search_test(self.user2, "NAME 2") + + _search_test(self.user1, "ER1", "user_id") + _search_test(self.user2, "ER2", "user_id") + + _search_test(None, "foo") + _search_test(None, "bar") + + _search_test(None, "foo", "user_id") + _search_test(None, "bar", "user_id") + class UserRestTestCase(unittest.HomeserverTestCase): @@ -429,7 +601,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - self.other_user = self.register_user("user", "pass") + self.other_user = self.register_user("user", "pass", displayname="User") self.other_user_token = self.login("user", "pass") self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote( self.other_user @@ -441,18 +613,14 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v2/users/@bob:test" - request, channel = self.make_request( - "GET", url, access_token=self.other_user_token, - ) - self.render(request) + channel = self.make_request("GET", url, access_token=self.other_user_token,) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("You are not a server admin", channel.json_body["error"]) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, access_token=self.other_user_token, content=b"{}", ) - self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("You are not a server admin", channel.json_body["error"]) @@ -462,12 +630,11 @@ class UserRestTestCase(unittest.HomeserverTestCase): Tests that a lookup for a user that does not exist returns a 404 """ - request, channel = self.make_request( + channel = self.make_request( "GET", "/_synapse/admin/v2/users/@unknown_person:test", access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"]) @@ -485,17 +652,16 @@ class UserRestTestCase(unittest.HomeserverTestCase): "admin": True, "displayname": "Bob's name", "threepids": 
[{"medium": "email", "address": "bob@bob.bob"}], - "avatar_url": None, + "avatar_url": "mxc://fibble/wibble", } ) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -503,12 +669,10 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertEqual(True, channel.json_body["admin"]) + self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) # Get user - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) - self.render(request) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -518,6 +682,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(True, channel.json_body["admin"]) self.assertEqual(False, channel.json_body["is_guest"]) self.assertEqual(False, channel.json_body["deactivated"]) + self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) def test_create_user(self): """ @@ -532,16 +697,16 @@ class UserRestTestCase(unittest.HomeserverTestCase): "admin": False, "displayname": "Bob's name", "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + "avatar_url": "mxc://fibble/wibble", } ) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -549,12 +714,10 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertEqual(False, channel.json_body["admin"]) + self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) # Get user - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) - self.render(request) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -564,6 +727,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(False, channel.json_body["admin"]) self.assertEqual(False, channel.json_body["is_guest"]) self.assertEqual(False, channel.json_body["deactivated"]) + self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) @override_config( {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0} @@ -579,10 +743,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Sync to set admin user to active # before limit of monthly active users is reached - request, channel = self.make_request( - "GET", "/sync", access_token=self.admin_user_tok - ) - self.render(request) + channel = self.make_request("GET", "/sync", access_token=self.admin_user_tok) if channel.code != 200: raise HttpResponseException( @@ -605,13 +766,12 @@ class 
UserRestTestCase(unittest.HomeserverTestCase): # Create user body = json.dumps({"password": "abc123", "admin": False}) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -645,13 +805,12 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Create user body = json.dumps({"password": "abc123", "admin": False}) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) # Admin user is not blocked by mau anymore self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) @@ -683,13 +842,12 @@ class UserRestTestCase(unittest.HomeserverTestCase): } ) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -701,7 +859,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) pushers = list(pushers) self.assertEqual(len(pushers), 1) - self.assertEqual("@bob:test", pushers[0]["user_name"]) + self.assertEqual("@bob:test", pushers[0].user_name) @override_config( { @@ -728,13 +886,12 @@ class UserRestTestCase(unittest.HomeserverTestCase): } ) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -755,13 +912,12 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Change password body = json.dumps({"password": "hahaha"}) - request, channel = self.make_request( + channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -773,23 +929,21 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Modify user body = json.dumps({"displayname": "foobar"}) - request, channel = self.make_request( + channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("foobar", channel.json_body["displayname"]) # Get user - request, channel = self.make_request( + channel = self.make_request( "GET", self.url_other_user, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) @@ -805,13 +959,12 @@ class UserRestTestCase(unittest.HomeserverTestCase): {"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]} ) - request, channel = self.make_request( + channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(200, 
int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) @@ -819,10 +972,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) # Get user - request, channel = self.make_request( + channel = self.make_request( "GET", self.url_other_user, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) @@ -837,13 +989,12 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Deactivate user body = json.dumps({"deactivated": True}) - request, channel = self.make_request( + channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) @@ -851,28 +1002,74 @@ class UserRestTestCase(unittest.HomeserverTestCase): # the user is deactivated, the threepid will be deleted # Get user - request, channel = self.make_request( + channel = self.make_request( "GET", self.url_other_user, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) + @override_config({"user_directory": {"enabled": True, "search_all_users": True}}) + def test_change_name_deactivate_user_user_directory(self): + """ + Test change profile information of a deactivated user and + check that it does not appear in user directory + """ + + # is in user directory + profile = self.get_success(self.store.get_user_in_directory(self.other_user)) + self.assertTrue(profile["display_name"] == "User") + + # Deactivate user + body = json.dumps({"deactivated": True}) + + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["deactivated"]) + + # is not in user directory + profile = self.get_success(self.store.get_user_in_directory(self.other_user)) + self.assertTrue(profile is None) + + # Set new displayname user + body = json.dumps({"displayname": "Foobar"}) + + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(True, channel.json_body["deactivated"]) + self.assertEqual("Foobar", channel.json_body["displayname"]) + + # is not in user directory + profile = self.get_success(self.store.get_user_in_directory(self.other_user)) + self.assertTrue(profile is None) + def test_reactivate_user(self): """ Test reactivating another user. """ # Deactivate the user. 
- request, channel = self.make_request( + channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, content=json.dumps({"deactivated": True}).encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self._is_erased("@user:test", False) d = self.store.mark_user_erased("@user:test") @@ -880,17 +1077,16 @@ class UserRestTestCase(unittest.HomeserverTestCase): self._is_erased("@user:test", True) # Attempt to reactivate the user (without a password). - request, channel = self.make_request( + channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, content=json.dumps({"deactivated": False}).encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) # Reactivate the user. - request, channel = self.make_request( + channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, @@ -898,14 +1094,12 @@ class UserRestTestCase(unittest.HomeserverTestCase): encoding="utf_8" ), ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Get user - request, channel = self.make_request( + channel = self.make_request( "GET", self.url_other_user, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) @@ -920,23 +1114,21 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set a user as an admin body = json.dumps({"admin": True}) - request, channel = self.make_request( + channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["admin"]) # Get user - request, channel = self.make_request( + channel = self.make_request( "GET", self.url_other_user, access_token=self.admin_user_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) @@ -952,23 +1144,19 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Create user body = json.dumps({"password": "abc123"}) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("bob", channel.json_body["displayname"]) # Get user - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) - self.render(request) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -978,21 +1166,17 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Change password (and use a str for deactivate instead of a bool) body = json.dumps({"password": "abc123", "deactivated": "false"}) # oops! 
- request, channel = self.make_request( + channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, content=body.encode(encoding="utf_8"), ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) # Check user is not deactivated - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) - self.render(request) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -1016,7 +1200,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, - sync.register_servlets, room.register_servlets, ] @@ -1035,8 +1218,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): """ Try to list rooms of an user without authentication. """ - request, channel = self.make_request("GET", self.url, b"{}") - self.render(request) + channel = self.make_request("GET", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -1047,10 +1229,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): """ other_user_token = self.login("user", "pass") - request, channel = self.make_request( - "GET", self.url, access_token=other_user_token, - ) - self.render(request) + channel = self.make_request("GET", self.url, access_token=other_user_token,) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -1060,10 +1239,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v1/users/@unknown_person:test/joined_rooms" - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) - self.render(request) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -1074,14 +1250,23 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/joined_rooms" - request, channel = self.make_request( - "GET", url, access_token=self.admin_user_tok, - ) - self.render(request) + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) + def test_no_memberships(self): + """ + Tests that a normal lookup for rooms is successfully + if user has no memberships + """ + # Get rooms + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + self.assertEqual(0, len(channel.json_body["joined_rooms"])) + def test_get_rooms(self): """ Tests that a normal lookup for rooms is successfully @@ -1093,11 +1278,675 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): self.helper.create_room_as(self.other_user, tok=other_user_tok) # Get rooms - request, channel = self.make_request( - "GET", self.url, access_token=self.admin_user_tok, - ) - self.render(request) + channel = 
self.make_request("GET", self.url, access_token=self.admin_user_tok,) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_rooms, channel.json_body["total"]) self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"])) + + +class PushersRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.url = "/_synapse/admin/v1/users/%s/pushers" % urllib.parse.quote( + self.other_user + ) + + def test_no_auth(self): + """ + Try to list pushers of an user without authentication. + """ + channel = self.make_request("GET", self.url, b"{}") + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + other_user_token = self.login("user", "pass") + + channel = self.make_request("GET", self.url, access_token=other_user_token,) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + url = "/_synapse/admin/v1/users/@unknown_person:test/pushers" + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that a lookup for a user that is not a local returns a 400 + """ + url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers" + + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only lookup local users", channel.json_body["error"]) + + def test_get_pushers(self): + """ + Tests that a normal lookup for pushers is successfully + """ + + # Get pushers + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + + # Register the pusher + other_user_token = self.login("user", "pass") + user_tuple = self.get_success( + self.store.get_user_by_access_token(other_user_token) + ) + token_id = user_tuple.token_id + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=self.other_user, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "https://example.com/_matrix/push/v1/notify"}, + ) + ) + + # Get pushers + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + + for p in channel.json_body["pushers"]: + self.assertIn("pushkey", p) + self.assertIn("kind", p) + self.assertIn("app_id", p) + self.assertIn("app_display_name", p) + self.assertIn("device_display_name", p) + 
self.assertIn("profile_tag", p) + self.assertIn("lang", p) + self.assertIn("url", p["data"]) + + +class UserMediaRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.media_repo = hs.get_media_repository_resource() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.url = "/_synapse/admin/v1/users/%s/media" % urllib.parse.quote( + self.other_user + ) + + def test_no_auth(self): + """ + Try to list media of an user without authentication. + """ + channel = self.make_request("GET", self.url, b"{}") + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + other_user_token = self.login("user", "pass") + + channel = self.make_request("GET", self.url, access_token=other_user_token,) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + url = "/_synapse/admin/v1/users/@unknown_person:test/media" + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that a lookup for a user that is not a local returns a 400 + """ + url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media" + + channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only lookup local users", channel.json_body["error"]) + + def test_limit(self): + """ + Testing list of media with limit + """ + + number_media = 20 + other_user_tok = self.login("user", "pass") + self._create_media(other_user_tok, number_media) + + channel = self.make_request( + "GET", self.url + "?limit=5", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_media) + self.assertEqual(len(channel.json_body["media"]), 5) + self.assertEqual(channel.json_body["next_token"], 5) + self._check_fields(channel.json_body["media"]) + + def test_from(self): + """ + Testing list of media with a defined starting point (from) + """ + + number_media = 20 + other_user_tok = self.login("user", "pass") + self._create_media(other_user_tok, number_media) + + channel = self.make_request( + "GET", self.url + "?from=5", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_media) + self.assertEqual(len(channel.json_body["media"]), 15) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["media"]) + + def test_limit_and_from(self): + """ + Testing list of media with a defined starting point and limit + """ + + number_media = 20 + other_user_tok = self.login("user", "pass") + 
self._create_media(other_user_tok, number_media) + + channel = self.make_request( + "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_media) + self.assertEqual(channel.json_body["next_token"], 15) + self.assertEqual(len(channel.json_body["media"]), 10) + self._check_fields(channel.json_body["media"]) + + def test_limit_is_negative(self): + """ + Testing that a negative limit parameter returns a 400 + """ + + channel = self.make_request( + "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + def test_from_is_negative(self): + """ + Testing that a negative from parameter returns a 400 + """ + + channel = self.make_request( + "GET", self.url + "?from=-5", access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + def test_next_token(self): + """ + Testing that `next_token` appears at the right place + """ + + number_media = 20 + other_user_tok = self.login("user", "pass") + self._create_media(other_user_tok, number_media) + + # `next_token` does not appear + # Number of results is the number of entries + channel = self.make_request( + "GET", self.url + "?limit=20", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_media) + self.assertEqual(len(channel.json_body["media"]), number_media) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does not appear + # Number of max results is larger than the number of entries + channel = self.make_request( + "GET", self.url + "?limit=21", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_media) + self.assertEqual(len(channel.json_body["media"]), number_media) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does appear + # Number of max results is smaller than the number of entries + channel = self.make_request( + "GET", self.url + "?limit=19", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_media) + self.assertEqual(len(channel.json_body["media"]), 19) + self.assertEqual(channel.json_body["next_token"], 19) + + # Check + # Set `from` to value of `next_token` for request remaining entries + # `next_token` does not appear + channel = self.make_request( + "GET", self.url + "?from=19", access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_media) + self.assertEqual(len(channel.json_body["media"]), 1) + self.assertNotIn("next_token", channel.json_body) + + def test_user_has_no_media(self): + """ + Tests that a normal lookup for media is successfully + if user has no media created + """ + + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, 
channel.json_body["total"]) + self.assertEqual(0, len(channel.json_body["media"])) + + def test_get_media(self): + """ + Tests that a normal lookup for media is successfully + """ + + number_media = 5 + other_user_tok = self.login("user", "pass") + self._create_media(other_user_tok, number_media) + + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(number_media, channel.json_body["total"]) + self.assertEqual(number_media, len(channel.json_body["media"])) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["media"]) + + def _create_media(self, user_token, number_media): + """ + Create a number of media for a specific user + """ + upload_resource = self.media_repo.children[b"upload"] + for i in range(number_media): + # file size is 67 Byte + image_data = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + + # Upload some media into the room + self.helper.upload_media( + upload_resource, image_data, tok=user_token, expect_code=200 + ) + + def _check_fields(self, content): + """Checks that all attributes are present in content + """ + for m in content: + self.assertIn("media_id", m) + self.assertIn("media_type", m) + self.assertIn("media_length", m) + self.assertIn("upload_name", m) + self.assertIn("created_ts", m) + self.assertIn("last_access_ts", m) + self.assertIn("quarantined_by", m) + self.assertIn("safe_from_quarantine", m) + + +class UserTokenRestTestCase(unittest.HomeserverTestCase): + """Test for /_synapse/admin/v1/users/<user>/login + """ + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + sync.register_servlets, + room.register_servlets, + devices.register_servlets, + logout.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + self.url = "/_synapse/admin/v1/users/%s/login" % urllib.parse.quote( + self.other_user + ) + + def _get_token(self) -> str: + channel = self.make_request( + "POST", self.url, b"{}", access_token=self.admin_user_tok + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + return channel.json_body["access_token"] + + def test_no_auth(self): + """Try to login as a user without authentication. + """ + channel = self.make_request("POST", self.url, b"{}") + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_not_admin(self): + """Try to login as a user as a non-admin user. + """ + channel = self.make_request( + "POST", self.url, b"{}", access_token=self.other_user_tok + ) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + + def test_send_event(self): + """Test that sending event as a user works. + """ + # Create a room. + room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_tok) + + # Login in as the user + puppet_token = self._get_token() + + # Test that sending works, and generates the event as the right user. 
+ resp = self.helper.send_event(room_id, "com.example.test", tok=puppet_token) + event_id = resp["event_id"] + event = self.get_success(self.store.get_event(event_id)) + self.assertEqual(event.sender, self.other_user) + + def test_devices(self): + """Tests that logging in as a user doesn't create a new device for them. + """ + # Login in as the user + self._get_token() + + # Check that we don't see a new device in our devices list + channel = self.make_request( + "GET", "devices", b"{}", access_token=self.other_user_tok + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # We should only see the one device (from the login in `prepare`) + self.assertEqual(len(channel.json_body["devices"]), 1) + + def test_logout(self): + """Test that calling `/logout` with the token works. + """ + # Login in as the user + puppet_token = self._get_token() + + # Test that we can successfully make a request + channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Logout with the puppet token + channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # The puppet token should no longer work + channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + + # .. but the real user's tokens should still work + channel = self.make_request( + "GET", "devices", b"{}", access_token=self.other_user_tok + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + def test_user_logout_all(self): + """Tests that the target user calling `/logout/all` does *not* expire + the token. + """ + # Login in as the user + puppet_token = self._get_token() + + # Test that we can successfully make a request + channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Logout all with the real user token + channel = self.make_request( + "POST", "logout/all", b"{}", access_token=self.other_user_tok + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # The puppet token should still work + channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # .. but the real user's tokens shouldn't + channel = self.make_request( + "GET", "devices", b"{}", access_token=self.other_user_tok + ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + + def test_admin_logout_all(self): + """Tests that the admin user calling `/logout/all` does expire the + token. 
+ """ + # Login in as the user + puppet_token = self._get_token() + + # Test that we can successfully make a request + channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Logout all with the admin user token + channel = self.make_request( + "POST", "logout/all", b"{}", access_token=self.admin_user_tok + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # The puppet token should no longer work + channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + + # .. but the real user's tokens should still work + channel = self.make_request( + "GET", "devices", b"{}", access_token=self.other_user_tok + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + @unittest.override_config( + { + "public_baseurl": "https://example.org/", + "user_consent": { + "version": "1.0", + "policy_name": "My Cool Privacy Policy", + "template_dir": "/", + "require_at_registration": True, + "block_events_error": "You should accept the policy", + }, + "form_secret": "123secret", + } + ) + def test_consent(self): + """Test that sending a message is not subject to the privacy policies. + """ + # Have the admin user accept the terms. + self.get_success(self.store.user_set_consent_version(self.admin_user, "1.0")) + + # First, cheekily accept the terms and create a room + self.get_success(self.store.user_set_consent_version(self.other_user, "1.0")) + room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_tok) + self.helper.send_event(room_id, "com.example.test", tok=self.other_user_tok) + + # Now unaccept it and check that we can't send an event + self.get_success(self.store.user_set_consent_version(self.other_user, "0.0")) + self.helper.send_event( + room_id, "com.example.test", tok=self.other_user_tok, expect_code=403 + ) + + # Login in as the user + puppet_token = self._get_token() + + # Sending an event on their behalf should work fine + self.helper.send_event(room_id, "com.example.test", tok=puppet_token) + + @override_config( + {"limit_usage_by_mau": True, "max_mau_value": 1, "mau_trial_days": 0} + ) + def test_mau_limit(self): + # Create a room as the admin user. This will bump the monthly active users to 1. + room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + # Trying to join as the other user should fail due to reaching MAU limit. + self.helper.join( + room_id, user=self.other_user, tok=self.other_user_tok, expect_code=403 + ) + + # Logging in as the other user and joining a room should work, even + # though the MAU limit would stop the user doing so. 
+ puppet_token = self._get_token() + self.helper.join(room_id, user=self.other_user, tok=puppet_token) + + +class WhoisRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.url1 = "/_synapse/admin/v1/whois/%s" % urllib.parse.quote(self.other_user) + self.url2 = "/_matrix/client/r0/admin/whois/%s" % urllib.parse.quote( + self.other_user + ) + + def test_no_auth(self): + """ + Try to get information of an user without authentication. + """ + channel = self.make_request("GET", self.url1, b"{}") + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + channel = self.make_request("GET", self.url2, b"{}") + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_not_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + self.register_user("user2", "pass") + other_user2_token = self.login("user2", "pass") + + channel = self.make_request("GET", self.url1, access_token=other_user2_token,) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + channel = self.make_request("GET", self.url2, access_token=other_user2_token,) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that a lookup for a user that is not a local returns a 400 + """ + url1 = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain" + url2 = "/_matrix/client/r0/admin/whois/@unknown_person:unknown_domain" + + channel = self.make_request("GET", url1, access_token=self.admin_user_tok,) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only whois a local user", channel.json_body["error"]) + + channel = self.make_request("GET", url2, access_token=self.admin_user_tok,) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only whois a local user", channel.json_body["error"]) + + def test_get_whois_admin(self): + """ + The lookup should succeed for an admin. + """ + channel = self.make_request("GET", self.url1, access_token=self.admin_user_tok,) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(self.other_user, channel.json_body["user_id"]) + self.assertIn("devices", channel.json_body) + + channel = self.make_request("GET", self.url2, access_token=self.admin_user_tok,) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(self.other_user, channel.json_body["user_id"]) + self.assertIn("devices", channel.json_body) + + def test_get_whois_user(self): + """ + The lookup should succeed for a normal user looking up their own information. 
+ """ + other_user_token = self.login("user", "pass") + + channel = self.make_request("GET", self.url1, access_token=other_user_token,) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(self.other_user, channel.json_body["user_id"]) + self.assertIn("devices", channel.json_body) + + channel = self.make_request("GET", self.url2, access_token=other_user_token,) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(self.other_user, channel.json_body["user_id"]) + self.assertIn("devices", channel.json_body) diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
index 6803b372ac..c74693e9b2 100644
--- a/tests/rest/client/test_consent.py
+++ b/tests/rest/client/test_consent.py
@@ -21,7 +21,7 @@ from synapse.rest.client.v1 import login, room from synapse.rest.consent import consent_resource from tests import unittest -from tests.server import render +from tests.server import FakeSite, make_request class ConsentResourceTestCase(unittest.HomeserverTestCase): @@ -61,8 +61,9 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): def test_render_public_consent(self): """You can observe the terms form without specifying a user""" resource = consent_resource.ConsentResource(self.hs) - request, channel = self.make_request("GET", "/consent?v=1", shorthand=False) - render(request, resource, self.reactor) + channel = make_request( + self.reactor, FakeSite(resource), "GET", "/consent?v=1", shorthand=False + ) self.assertEqual(channel.code, 200) def test_accept_consent(self): @@ -81,10 +82,14 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): uri_builder.build_user_consent_uri(user_id).replace("_matrix/", "") + "&u=user" ) - request, channel = self.make_request( - "GET", consent_uri, access_token=access_token, shorthand=False + channel = make_request( + self.reactor, + FakeSite(resource), + "GET", + consent_uri, + access_token=access_token, + shorthand=False, ) - render(request, resource, self.reactor) self.assertEqual(channel.code, 200) # Get the version from the body, and whether we've consented @@ -92,21 +97,26 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): self.assertEqual(consented, "False") # POST to the consent page, saying we've agreed - request, channel = self.make_request( + channel = make_request( + self.reactor, + FakeSite(resource), "POST", consent_uri + "&v=" + version, access_token=access_token, shorthand=False, ) - render(request, resource, self.reactor) self.assertEqual(channel.code, 200) # Fetch the consent page, to get the consent version -- it should have # changed - request, channel = self.make_request( - "GET", consent_uri, access_token=access_token, shorthand=False + channel = make_request( + self.reactor, + FakeSite(resource), + "GET", + consent_uri, + access_token=access_token, + shorthand=False, ) - render(request, resource, self.reactor) self.assertEqual(channel.code, 200) # Get the version from the body, and check that it's the version we diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py
index 5e9c07ebf3..56937dcd2e 100644
--- a/tests/rest/client/test_ephemeral_message.py
+++ b/tests/rest/client/test_ephemeral_message.py
@@ -93,8 +93,7 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase): def get_event(self, room_id, event_id, expected_code=200): url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) - request, channel = self.make_request("GET", url) - self.render(request) + channel = self.make_request("GET", url) self.assertEqual(channel.code, expected_code, channel.result) diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index c973521907..c0a9fc6925 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -43,10 +43,7 @@ class IdentityTestCase(unittest.HomeserverTestCase): self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") - request, channel = self.make_request( - b"POST", "/createRoom", b"{}", access_token=tok - ) - self.render(request) + channel = self.make_request(b"POST", "/createRoom", b"{}", access_token=tok) self.assertEquals(channel.result["code"], b"200", channel.result) room_id = channel.json_body["room_id"] @@ -57,8 +54,7 @@ class IdentityTestCase(unittest.HomeserverTestCase): } request_data = json.dumps(params) request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii") - request, channel = self.make_request( + channel = self.make_request( b"POST", request_url, request_data, access_token=tok ) - self.render(request) self.assertEquals(channel.result["code"], b"403", channel.result) diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index d2bcf256fa..f0707646bb 100644
--- a/tests/rest/client/test_redactions.py
+++ b/tests/rest/client/test_redactions.py
@@ -69,18 +69,12 @@ class RedactionsTestCase(HomeserverTestCase): """ path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id) - request, channel = self.make_request( - "POST", path, content={}, access_token=access_token - ) - self.render(request) + channel = self.make_request("POST", path, content={}, access_token=access_token) self.assertEqual(int(channel.result["code"]), expect_code) return channel.json_body def _sync_room_timeline(self, access_token, room_id): - request, channel = self.make_request( - "GET", "sync", access_token=self.mod_access_token - ) - self.render(request) + channel = self.make_request("GET", "sync", access_token=self.mod_access_token) self.assertEqual(channel.result["code"], b"200") room_sync = channel.json_body["rooms"]["join"][room_id] return room_sync["timeline"]["events"] diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 7d3773ff78..31dc832fd5 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -325,8 +325,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): def get_event(self, room_id, event_id, expected_code=200): url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) - request, channel = self.make_request("GET", url, access_token=self.token) - self.render(request) + channel = self.make_request("GET", url, access_token=self.token) self.assertEqual(channel.code, expected_code, channel.result) diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index dfe4bf7762..e689c3fbea 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -78,7 +78,7 @@ class RoomTestCase(_ShadowBannedBase): def test_invite_3pid(self): """Ensure that a 3PID invite does not attempt to contact the identity server.""" - identity_handler = self.hs.get_handlers().identity_handler + identity_handler = self.hs.get_identity_handler() identity_handler.lookup_3pid = Mock( side_effect=AssertionError("This should not get called") ) @@ -89,13 +89,12 @@ class RoomTestCase(_ShadowBannedBase): ) # Inviting the user completes successfully. - request, channel = self.make_request( + channel = self.make_request( "POST", "/rooms/%s/invite" % (room_id,), {"id_server": "test", "medium": "email", "address": "test@test.test"}, access_token=self.banned_access_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.result) # This should have raised an error earlier, but double check this wasn't called. @@ -104,13 +103,12 @@ class RoomTestCase(_ShadowBannedBase): def test_create_room(self): """Invitations during a room creation should be discarded, but the room still gets created.""" # The room creation is successful. - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/createRoom", {"visibility": "public", "invite": [self.other_user_id]}, access_token=self.banned_access_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.result) room_id = channel.json_body["room_id"] @@ -160,13 +158,12 @@ class RoomTestCase(_ShadowBannedBase): self.banned_user_id, tok=self.banned_access_token ) - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/rooms/%s/upgrade" % (room_id,), {"new_version": "6"}, access_token=self.banned_access_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.result) # A new room_id should be returned. self.assertIn("replacement_room", channel.json_body) @@ -186,13 +183,12 @@ class RoomTestCase(_ShadowBannedBase): self.banned_user_id, tok=self.banned_access_token ) - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (room_id, self.banned_user_id), {"typing": True, "timeout": 30000}, access_token=self.banned_access_token, ) - self.render(request) self.assertEquals(200, channel.code) # There should be no typing events. @@ -202,13 +198,12 @@ class RoomTestCase(_ShadowBannedBase): # The other user can join and send typing events. self.helper.join(room_id, self.other_user_id, tok=self.other_access_token) - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (room_id, self.other_user_id), {"typing": True, "timeout": 30000}, access_token=self.other_access_token, ) - self.render(request) self.assertEquals(200, channel.code) # These appear in the room. @@ -249,21 +244,19 @@ class ProfileTestCase(_ShadowBannedBase): ) # The update should succeed. - request, channel = self.make_request( + channel = self.make_request( "PUT", "/_matrix/client/r0/profile/%s/displayname" % (self.banned_user_id,), {"displayname": new_display_name}, access_token=self.banned_access_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.result) self.assertEqual(channel.json_body, {}) # The user's display name should be updated. 
- request, channel = self.make_request( + channel = self.make_request( "GET", "/profile/%s/displayname" % (self.banned_user_id,) ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.json_body["displayname"], new_display_name) @@ -289,14 +282,13 @@ class ProfileTestCase(_ShadowBannedBase): ) # The update should succeed. - request, channel = self.make_request( + channel = self.make_request( "PUT", "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (room_id, self.banned_user_id), {"membership": "join", "displayname": new_display_name}, access_token=self.banned_access_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.result) self.assertIn("event_id", channel.json_body) diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py new file mode 100644
index 0000000000..227fffab58
--- /dev/null
+++ b/tests/rest/client/test_third_party_rules.py
@@ -0,0 +1,185 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import threading +from typing import Dict + +from mock import Mock + +from synapse.events import EventBase +from synapse.module_api import ModuleApi +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.types import Requester, StateMap + +from tests import unittest + +thread_local = threading.local() + + +class ThirdPartyRulesTestModule: + def __init__(self, config: Dict, module_api: ModuleApi): + # keep a record of the "current" rules module, so that the test can patch + # it if desired. + thread_local.rules_module = self + self.module_api = module_api + + async def on_create_room( + self, requester: Requester, config: dict, is_requester_admin: bool + ): + return True + + async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]): + return True + + @staticmethod + def parse_config(config): + return config + + +def current_rules_module() -> ThirdPartyRulesTestModule: + return thread_local.rules_module + + +class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def default_config(self): + config = super().default_config() + config["third_party_event_rules"] = { + "module": __name__ + ".ThirdPartyRulesTestModule", + "config": {}, + } + return config + + def prepare(self, reactor, clock, homeserver): + # Create a user and room to play with during the tests + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + def test_third_party_rules(self): + """Tests that a forbidden event is forbidden from being sent, but an allowed one + can be sent. 
+ """ + # patch the rules module with a Mock which will return False for some event + # types + async def check(ev, state): + return ev.type != "foo.bar.forbidden" + + callback = Mock(spec=[], side_effect=check) + current_rules_module().check_event_allowed = callback + + channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % self.room_id, + {}, + access_token=self.tok, + ) + self.assertEquals(channel.result["code"], b"200", channel.result) + + callback.assert_called_once() + + # there should be various state events in the state arg: do some basic checks + state_arg = callback.call_args[0][1] + for k in (("m.room.create", ""), ("m.room.member", self.user_id)): + self.assertIn(k, state_arg) + ev = state_arg[k] + self.assertEqual(ev.type, k[0]) + self.assertEqual(ev.state_key, k[1]) + + channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/2" % self.room_id, + {}, + access_token=self.tok, + ) + self.assertEquals(channel.result["code"], b"403", channel.result) + + def test_cannot_modify_event(self): + """cannot accidentally modify an event before it is persisted""" + + # first patch the event checker so that it will try to modify the event + async def check(ev: EventBase, state): + ev.content = {"x": "y"} + return True + + current_rules_module().check_event_allowed = check + + # now send the event + channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id, + {"x": "x"}, + access_token=self.tok, + ) + self.assertEqual(channel.result["code"], b"500", channel.result) + + def test_modify_event(self): + """The module can return a modified version of the event""" + # first patch the event checker so that it will modify the event + async def check(ev: EventBase, state): + d = ev.get_dict() + d["content"] = {"x": "y"} + return d + + current_rules_module().check_event_allowed = check + + # now send the event + channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id, + {"x": "x"}, + access_token=self.tok, + ) + self.assertEqual(channel.result["code"], b"200", channel.result) + event_id = channel.json_body["event_id"] + + # ... and check that it got modified + channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), + access_token=self.tok, + ) + self.assertEqual(channel.result["code"], b"200", channel.result) + ev = channel.json_body + self.assertEqual(ev["content"]["x"], "y") + + def test_send_event(self): + """Tests that the module can send an event into a room via the module api""" + content = { + "msgtype": "m.text", + "body": "Hello!", + } + event_dict = { + "room_id": self.room_id, + "type": "m.room.message", + "content": content, + "sender": self.user_id, + } + event = self.get_success( + current_rules_module().module_api.create_and_send_event_into_room( + event_dict + ) + ) # type: EventBase + + self.assertEquals(event.sender, self.user_id) + self.assertEquals(event.room_id, self.room_id) + self.assertEquals(event.type, "m.room.message") + self.assertEquals(event.content, content) diff --git a/tests/rest/client/third_party_rules.py b/tests/rest/client/third_party_rules.py deleted file mode 100644
index 8c24add530..0000000000
--- a/tests/rest/client/third_party_rules.py
+++ /dev/null
@@ -1,79 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the 'License'); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an 'AS IS' BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.rest import admin -from synapse.rest.client.v1 import login, room - -from tests import unittest - - -class ThirdPartyRulesTestModule: - def __init__(self, config): - pass - - def check_event_allowed(self, event, context): - if event.type == "foo.bar.forbidden": - return False - else: - return True - - @staticmethod - def parse_config(config): - return config - - -class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): - servlets = [ - admin.register_servlets, - login.register_servlets, - room.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - config = self.default_config() - config["third_party_event_rules"] = { - "module": "tests.rest.client.third_party_rules.ThirdPartyRulesTestModule", - "config": {}, - } - - self.hs = self.setup_test_homeserver(config=config) - return self.hs - - def test_third_party_rules(self): - """Tests that a forbidden event is forbidden from being sent, but an allowed one - can be sent. - """ - user_id = self.register_user("kermit", "monkey") - tok = self.login("kermit", "monkey") - - room_id = self.helper.create_room_as(user_id, tok=tok) - - request, channel = self.make_request( - "PUT", - "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % room_id, - {}, - access_token=tok, - ) - self.render(request) - self.assertEquals(channel.result["code"], b"200", channel.result) - - request, channel = self.make_request( - "PUT", - "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % room_id, - {}, - access_token=tok, - ) - self.render(request) - self.assertEquals(channel.result["code"], b"403", channel.result) diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py
index 633b7dbda0..edd1d184f8 100644
--- a/tests/rest/client/v1/test_directory.py
+++ b/tests/rest/client/v1/test_directory.py
@@ -21,6 +21,7 @@ from synapse.types import RoomAlias from synapse.util.stringutils import random_string from tests import unittest +from tests.unittest import override_config class DirectoryTestCase(unittest.HomeserverTestCase): @@ -67,10 +68,18 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.ensure_user_joined_room() self.set_alias_via_directory(400, alias_length=256) - def test_state_event_in_room(self): + @override_config({"default_room_version": 5}) + def test_state_event_user_in_v5_room(self): + """Test that a regular user can add alias events before room v6""" self.ensure_user_joined_room() self.set_alias_via_state_event(200) + @override_config({"default_room_version": 6}) + def test_state_event_v6_room(self): + """Test that a regular user can *not* add alias events from room v6""" + self.ensure_user_joined_room() + self.set_alias_via_state_event(403) + def test_directory_in_room(self): self.ensure_user_joined_room() self.set_alias_via_directory(200) @@ -82,10 +91,9 @@ class DirectoryTestCase(unittest.HomeserverTestCase): # that we can make sure that the check is done on the whole alias. data = {"room_alias_name": random_string(256 - len(self.hs.hostname))} request_data = json.dumps(data) - request, channel = self.make_request( + channel = self.make_request( "POST", url, request_data, access_token=self.user_tok ) - self.render(request) self.assertEqual(channel.code, 400, channel.result) def test_room_creation(self): @@ -96,10 +104,9 @@ class DirectoryTestCase(unittest.HomeserverTestCase): # as cautious as possible here. data = {"room_alias_name": random_string(5)} request_data = json.dumps(data) - request, channel = self.make_request( + channel = self.make_request( "POST", url, request_data, access_token=self.user_tok ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) def set_alias_via_state_event(self, expected_code, alias_length=5): @@ -111,10 +118,9 @@ class DirectoryTestCase(unittest.HomeserverTestCase): data = {"aliases": [self.random_alias(alias_length)]} request_data = json.dumps(data) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, request_data, access_token=self.user_tok ) - self.render(request) self.assertEqual(channel.code, expected_code, channel.result) def set_alias_via_directory(self, expected_code, alias_length=5): @@ -122,10 +128,9 @@ class DirectoryTestCase(unittest.HomeserverTestCase): data = {"room_id": self.room_id} request_data = json.dumps(data) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, request_data, access_token=self.user_tok ) - self.render(request) self.assertEqual(channel.code, expected_code, channel.result) def random_alias(self, length): diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index f75520877f..0a5ca317ea 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -42,7 +42,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver(config=config) - hs.get_handlers().federation_handler = Mock() + hs.get_federation_handler = Mock() return hs @@ -63,17 +63,15 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): # implementation is now part of the r0 implementation, the newer # behaviour is used instead to be consistent with the r0 spec. # see issue #2602 - request, channel = self.make_request( + channel = self.make_request( "GET", "/events?access_token=%s" % ("invalid" + self.token,) ) - self.render(request) self.assertEquals(channel.code, 401, msg=channel.result) # valid token, expect content - request, channel = self.make_request( + channel = self.make_request( "GET", "/events?access_token=%s&timeout=0" % (self.token,) ) - self.render(request) self.assertEquals(channel.code, 200, msg=channel.result) self.assertTrue("chunk" in channel.json_body) self.assertTrue("start" in channel.json_body) @@ -89,10 +87,9 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): ) # valid token, expect content - request, channel = self.make_request( + channel = self.make_request( "GET", "/events?access_token=%s&timeout=0" % (self.token,) ) - self.render(request) self.assertEquals(channel.code, 200, msg=channel.result) # We may get a presence event for ourselves down @@ -152,8 +149,7 @@ class GetEventsTestCase(unittest.HomeserverTestCase): resp = self.helper.send(self.room_id, tok=self.token) event_id = resp["event_id"] - request, channel = self.make_request( + channel = self.make_request( "GET", "/events/" + event_id, access_token=self.token, ) - self.render(request) self.assertEquals(channel.code, 200, msg=channel.result) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 5d987a30c7..18932d7518 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -63,8 +63,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, "password": "monkey", } - request, channel = self.make_request(b"POST", LOGIN_URL, params) - self.render(request) + channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: self.assertEquals(channel.result["code"], b"429", channel.result) @@ -83,8 +82,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, "password": "monkey", } - request, channel = self.make_request(b"POST", LOGIN_URL, params) - self.render(request) + channel = self.make_request(b"POST", LOGIN_URL, params) self.assertEquals(channel.result["code"], b"200", channel.result) @@ -110,8 +108,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } - request, channel = self.make_request(b"POST", LOGIN_URL, params) - self.render(request) + channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: self.assertEquals(channel.result["code"], b"429", channel.result) @@ -130,8 +127,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } - request, channel = self.make_request(b"POST", LOGIN_URL, params) - self.render(request) + channel = self.make_request(b"POST", LOGIN_URL, params) self.assertEquals(channel.result["code"], b"200", channel.result) @@ -157,8 +153,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "notamonkey", } - request, channel = self.make_request(b"POST", LOGIN_URL, params) - self.render(request) + channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: self.assertEquals(channel.result["code"], b"429", channel.result) @@ -177,8 +172,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "notamonkey", } - request, channel = self.make_request(b"POST", LOGIN_URL, params) - self.render(request) + channel = self.make_request(b"POST", LOGIN_URL, params) self.assertEquals(channel.result["code"], b"403", channel.result) @@ -187,8 +181,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.register_user("kermit", "monkey") # we shouldn't be able to make requests without an access token - request, channel = self.make_request(b"GET", TEST_URL) - self.render(request) + channel = self.make_request(b"GET", TEST_URL) self.assertEquals(channel.result["code"], b"401", channel.result) self.assertEquals(channel.json_body["errcode"], "M_MISSING_TOKEN") @@ -198,28 +191,21 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } - request, channel = self.make_request(b"POST", LOGIN_URL, params) - self.render(request) + channel = self.make_request(b"POST", LOGIN_URL, params) self.assertEquals(channel.code, 200, channel.result) access_token = channel.json_body["access_token"] device_id = channel.json_body["device_id"] # we should now be able to make requests with the access token - request, channel = self.make_request( - b"GET", TEST_URL, access_token=access_token - ) - self.render(request) + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) self.assertEquals(channel.code, 200, channel.result) # time passes self.reactor.advance(24 * 3600) # ... 
and we should be soft-logouted - request, channel = self.make_request( - b"GET", TEST_URL, access_token=access_token - ) - self.render(request) + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) self.assertEquals(channel.code, 401, channel.result) self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEquals(channel.json_body["soft_logout"], True) @@ -233,10 +219,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # more requests with the expired token should still return a soft-logout self.reactor.advance(3600) - request, channel = self.make_request( - b"GET", TEST_URL, access_token=access_token - ) - self.render(request) + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) self.assertEquals(channel.code, 401, channel.result) self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEquals(channel.json_body["soft_logout"], True) @@ -244,20 +227,16 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # ... but if we delete that device, it will be a proper logout self._delete_device(access_token_2, "kermit", "monkey", device_id) - request, channel = self.make_request( - b"GET", TEST_URL, access_token=access_token - ) - self.render(request) + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) self.assertEquals(channel.code, 401, channel.result) self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEquals(channel.json_body["soft_logout"], False) def _delete_device(self, access_token, user_id, password, device_id): """Perform the UI-Auth to delete a device""" - request, channel = self.make_request( + channel = self.make_request( b"DELETE", "devices/" + device_id, access_token=access_token ) - self.render(request) self.assertEquals(channel.code, 401, channel.result) # check it's a UI-Auth fail self.assertEqual( @@ -275,13 +254,12 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "session": channel.json_body["session"], } - request, channel = self.make_request( + channel = self.make_request( b"DELETE", "devices/" + device_id, access_token=access_token, content={"auth": auth}, ) - self.render(request) self.assertEquals(channel.code, 200, channel.result) @override_config({"session_lifetime": "24h"}) @@ -292,29 +270,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): access_token = self.login("kermit", "monkey") # we should now be able to make requests with the access token - request, channel = self.make_request( - b"GET", TEST_URL, access_token=access_token - ) - self.render(request) + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) self.assertEquals(channel.code, 200, channel.result) # time passes self.reactor.advance(24 * 3600) # ... 
and we should be soft-logouted - request, channel = self.make_request( - b"GET", TEST_URL, access_token=access_token - ) - self.render(request) + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) self.assertEquals(channel.code, 401, channel.result) self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEquals(channel.json_body["soft_logout"], True) # Now try to hard logout this session - request, channel = self.make_request( - b"POST", "/logout", access_token=access_token - ) - self.render(request) + channel = self.make_request(b"POST", "/logout", access_token=access_token) self.assertEquals(channel.result["code"], b"200", channel.result) @override_config({"session_lifetime": "24h"}) @@ -325,29 +294,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): access_token = self.login("kermit", "monkey") # we should now be able to make requests with the access token - request, channel = self.make_request( - b"GET", TEST_URL, access_token=access_token - ) - self.render(request) + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) self.assertEquals(channel.code, 200, channel.result) # time passes self.reactor.advance(24 * 3600) # ... and we should be soft-logouted - request, channel = self.make_request( - b"GET", TEST_URL, access_token=access_token - ) - self.render(request) + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) self.assertEquals(channel.code, 401, channel.result) self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEquals(channel.json_body["soft_logout"], True) # Now try to hard log out all of the user's sessions - request, channel = self.make_request( - b"POST", "/logout/all", access_token=access_token - ) - self.render(request) + channel = self.make_request(b"POST", "/logout/all", access_token=access_token) self.assertEquals(channel.result["code"], b"200", channel.result) @@ -422,8 +382,7 @@ class CASTestCase(unittest.HomeserverTestCase): cas_ticket_url = urllib.parse.urlunparse(url_parts) # Get Synapse to call the fake CAS and serve the template. - request, channel = self.make_request("GET", cas_ticket_url) - self.render(request) + channel = self.make_request("GET", cas_ticket_url) # Test that the response is HTML. self.assertEqual(channel.code, 200) @@ -467,8 +426,7 @@ class CASTestCase(unittest.HomeserverTestCase): ) # Get Synapse to call the fake CAS and serve the template. - request, channel = self.make_request("GET", cas_ticket_url) - self.render(request) + channel = self.make_request("GET", cas_ticket_url) self.assertEqual(channel.code, 302) location_headers = channel.headers.getRawHeaders("Location") @@ -494,8 +452,7 @@ class CASTestCase(unittest.HomeserverTestCase): ) # Get Synapse to call the fake CAS and serve the template. - request, channel = self.make_request("GET", cas_ticket_url) - self.render(request) + channel = self.make_request("GET", cas_ticket_url) # Because the user is deactivated they are served an error template. self.assertEqual(channel.code, 403) @@ -518,15 +475,18 @@ class JWTTestCase(unittest.HomeserverTestCase): self.hs.config.jwt_algorithm = self.jwt_algorithm return self.hs - def jwt_encode(self, token, secret=jwt_secret): - return jwt.encode(token, secret, self.jwt_algorithm).decode("ascii") + def jwt_encode(self, token: str, secret: str = jwt_secret) -> str: + # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. 
+ result = jwt.encode(token, secret, self.jwt_algorithm) + if isinstance(result, bytes): + return result.decode("ascii") + return result def jwt_login(self, *args): params = json.dumps( {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} ) - request, channel = self.make_request(b"POST", LOGIN_URL, params) - self.render(request) + channel = self.make_request(b"POST", LOGIN_URL, params) return channel def test_login_jwt_valid_registered(self): @@ -658,8 +618,7 @@ class JWTTestCase(unittest.HomeserverTestCase): def test_login_no_token(self): params = json.dumps({"type": "org.matrix.login.jwt"}) - request, channel = self.make_request(b"POST", LOGIN_URL, params) - self.render(request) + channel = self.make_request(b"POST", LOGIN_URL, params) self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Token field for JWT is missing") @@ -725,15 +684,18 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): self.hs.config.jwt_algorithm = "RS256" return self.hs - def jwt_encode(self, token, secret=jwt_privatekey): - return jwt.encode(token, secret, "RS256").decode("ascii") + def jwt_encode(self, token: str, secret: str = jwt_privatekey) -> str: + # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. + result = jwt.encode(token, secret, "RS256") + if isinstance(result, bytes): + return result.decode("ascii") + return result def jwt_login(self, *args): params = json.dumps( {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} ) - request, channel = self.make_request(b"POST", LOGIN_URL, params) - self.render(request) + channel = self.make_request(b"POST", LOGIN_URL, params) return channel def test_login_jwt_valid(self): @@ -761,12 +723,11 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): ] def register_as_user(self, username): - request, channel = self.make_request( + self.make_request( b"POST", "/_matrix/client/r0/register?access_token=%s" % (self.service.token,), {"username": username}, ) - self.render(request) def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver() @@ -811,11 +772,10 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): "type": login.LoginRestServlet.APPSERVICE_TYPE, "identifier": {"type": "m.id.user", "user": AS_USER}, } - request, channel = self.make_request( + channel = self.make_request( b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) def test_login_appservice_user_bot(self): @@ -827,11 +787,10 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): "type": login.LoginRestServlet.APPSERVICE_TYPE, "identifier": {"type": "m.id.user", "user": self.service.sender}, } - request, channel = self.make_request( + channel = self.make_request( b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) def test_login_appservice_wrong_user(self): @@ -843,11 +802,10 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): "type": login.LoginRestServlet.APPSERVICE_TYPE, "identifier": {"type": "m.id.user", "user": "fibble_wibble"}, } - request, channel = self.make_request( + channel = self.make_request( b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.render(request) self.assertEquals(channel.result["code"], b"403", 
channel.result) def test_login_appservice_wrong_as(self): @@ -859,11 +817,10 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): "type": login.LoginRestServlet.APPSERVICE_TYPE, "identifier": {"type": "m.id.user", "user": AS_USER}, } - request, channel = self.make_request( + channel = self.make_request( b"POST", LOGIN_URL, params, access_token=self.another_service.token ) - self.render(request) self.assertEquals(channel.result["code"], b"403", channel.result) def test_login_appservice_no_token(self): @@ -876,7 +833,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): "type": login.LoginRestServlet.APPSERVICE_TYPE, "identifier": {"type": "m.id.user", "user": AS_USER}, } - request, channel = self.make_request(b"POST", LOGIN_URL, params) + channel = self.make_request(b"POST", LOGIN_URL, params) - self.render(request) self.assertEquals(channel.result["code"], b"401", channel.result) diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 3c66255dac..94a5154834 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -33,13 +33,16 @@ class PresenceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): + presence_handler = Mock() + presence_handler.set_state.return_value = defer.succeed(None) + hs = self.setup_test_homeserver( - "red", http_client=None, federation_client=Mock() + "red", + federation_http_client=None, + federation_client=Mock(), + presence_handler=presence_handler, ) - hs.presence_handler = Mock() - hs.presence_handler.set_state.return_value = defer.succeed(None) - return hs def test_put_presence(self): @@ -50,13 +53,12 @@ class PresenceTestCase(unittest.HomeserverTestCase): self.hs.config.use_presence = True body = {"presence": "here", "status_msg": "beep boop"} - request, channel = self.make_request( + channel = self.make_request( "PUT", "/presence/%s/status" % (self.user_id,), body ) - self.render(request) self.assertEqual(channel.code, 200) - self.assertEqual(self.hs.presence_handler.set_state.call_count, 1) + self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1) def test_put_presence_disabled(self): """ @@ -66,10 +68,9 @@ class PresenceTestCase(unittest.HomeserverTestCase): self.hs.config.use_presence = False body = {"presence": "here", "status_msg": "beep boop"} - request, channel = self.make_request( + channel = self.make_request( "PUT", "/presence/%s/status" % (self.user_id,), body ) - self.render(request) self.assertEqual(channel.code, 200) - self.assertEqual(self.hs.presence_handler.set_state.call_count, 0) + self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0) diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index ace0a3c08d..e59fa70baa 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -63,7 +63,7 @@ class MockHandlerProfileTestCase(unittest.TestCase): hs = yield setup_test_homeserver( self.addCleanup, "test", - http_client=None, + federation_http_client=None, resource_for_client=self.mock_resource, federation=Mock(), federation_client=Mock(), @@ -189,13 +189,12 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.owner_tok = self.login("owner", "pass") def test_set_displayname(self): - request, channel = self.make_request( + channel = self.make_request( "PUT", "/profile/%s/displayname" % (self.owner,), content=json.dumps({"displayname": "test"}), access_token=self.owner_tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) res = self.get_displayname() @@ -203,23 +202,19 @@ class ProfileTestCase(unittest.HomeserverTestCase): def test_set_displayname_too_long(self): """Attempts to set a stupid displayname should get a 400""" - request, channel = self.make_request( + channel = self.make_request( "PUT", "/profile/%s/displayname" % (self.owner,), content=json.dumps({"displayname": "test" * 100}), access_token=self.owner_tok, ) - self.render(request) self.assertEqual(channel.code, 400, channel.result) res = self.get_displayname() self.assertEqual(res, "owner") def get_displayname(self): - request, channel = self.make_request( - "GET", "/profile/%s/displayname" % (self.owner,) - ) - self.render(request) + channel = self.make_request("GET", "/profile/%s/displayname" % (self.owner,)) self.assertEqual(channel.code, 200, channel.result) return channel.json_body["displayname"] @@ -281,10 +276,9 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): ) def request_profile(self, expected_code, url_suffix="", access_token=None): - request, channel = self.make_request( + channel = self.make_request( "GET", self.profile_url + url_suffix, access_token=access_token ) - self.render(request) self.assertEqual(channel.code, expected_code, channel.result) def ensure_requester_left_room(self): @@ -324,24 +318,21 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): """Tests that a user can lookup their own profile without having to be in a room if 'require_auth_for_profile_requests' is set to true in the server's config. """ - request, channel = self.make_request( + channel = self.make_request( "GET", "/profile/" + self.requester, access_token=self.requester_tok ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) - request, channel = self.make_request( + channel = self.make_request( "GET", "/profile/" + self.requester + "/displayname", access_token=self.requester_tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) - request, channel = self.make_request( + channel = self.make_request( "GET", "/profile/" + self.requester + "/avatar_url", access_token=self.requester_tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/v1/test_push_rule_attrs.py
index 081052f6a6..2bc512d75e 100644
--- a/tests/rest/client/v1/test_push_rule_attrs.py
+++ b/tests/rest/client/v1/test_push_rule_attrs.py
@@ -45,17 +45,15 @@ class PushRuleAttributesTestCase(HomeserverTestCase): } # PUT a new rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) # GET enabled for that new rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["enabled"], True) @@ -76,49 +74,43 @@ class PushRuleAttributesTestCase(HomeserverTestCase): } # PUT a new rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) # disable the rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend/enabled", {"enabled": False}, access_token=token, ) - self.render(request) self.assertEqual(channel.code, 200) # check rule disabled - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["enabled"], False) # DELETE the rule - request, channel = self.make_request( + channel = self.make_request( "DELETE", "/pushrules/global/override/best.friend", access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) # PUT a new rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) # GET enabled for that new rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["enabled"], True) @@ -138,45 +130,40 @@ class PushRuleAttributesTestCase(HomeserverTestCase): } # PUT a new rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) # disable the rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend/enabled", {"enabled": False}, access_token=token, ) - self.render(request) self.assertEqual(channel.code, 200) # check rule disabled - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["enabled"], False) # re-enable the rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend/enabled", {"enabled": True}, access_token=token, ) - self.render(request) self.assertEqual(channel.code, 200) # check rule enabled - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["enabled"], True) @@ -195,39 +182,34 @@ class 
PushRuleAttributesTestCase(HomeserverTestCase): } # check 404 for never-heard-of rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) - self.render(request) self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) # PUT a new rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) # GET enabled for that new rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) # DELETE the rule - request, channel = self.make_request( + channel = self.make_request( "DELETE", "/pushrules/global/override/best.friend", access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) # check 404 for deleted rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) - self.render(request) self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) @@ -239,10 +221,9 @@ class PushRuleAttributesTestCase(HomeserverTestCase): token = self.login("user", "pass") # check 404 for never-heard-of rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/.m.muahahaha/enabled", access_token=token ) - self.render(request) self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) @@ -254,13 +235,12 @@ class PushRuleAttributesTestCase(HomeserverTestCase): token = self.login("user", "pass") # enable & check 404 for never-heard-of rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend/enabled", {"enabled": True}, access_token=token, ) - self.render(request) self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) @@ -272,13 +252,12 @@ class PushRuleAttributesTestCase(HomeserverTestCase): token = self.login("user", "pass") # enable & check 404 for never-heard-of rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/.m.muahahah/enabled", {"enabled": True}, access_token=token, ) - self.render(request) self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) @@ -297,17 +276,15 @@ class PushRuleAttributesTestCase(HomeserverTestCase): } # PUT a new rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) # GET actions for that new rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/actions", access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) self.assertEqual( channel.json_body["actions"], ["notify", {"set_tweak": "highlight"}] @@ -328,27 +305,24 @@ class PushRuleAttributesTestCase(HomeserverTestCase): } # PUT a new rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) - self.render(request) 
self.assertEqual(channel.code, 200) # change the rule actions - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend/actions", {"actions": ["dont_notify"]}, access_token=token, ) - self.render(request) self.assertEqual(channel.code, 200) # GET actions for that new rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/actions", access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["actions"], ["dont_notify"]) @@ -367,32 +341,28 @@ class PushRuleAttributesTestCase(HomeserverTestCase): } # check 404 for never-heard-of rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) - self.render(request) self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) # PUT a new rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend", body, access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) # DELETE the rule - request, channel = self.make_request( + channel = self.make_request( "DELETE", "/pushrules/global/override/best.friend", access_token=token ) - self.render(request) self.assertEqual(channel.code, 200) # check 404 for deleted rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/best.friend/enabled", access_token=token ) - self.render(request) self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) @@ -404,10 +374,9 @@ class PushRuleAttributesTestCase(HomeserverTestCase): token = self.login("user", "pass") # check 404 for never-heard-of rule - request, channel = self.make_request( + channel = self.make_request( "GET", "/pushrules/global/override/.m.muahahaha/actions", access_token=token ) - self.render(request) self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) @@ -419,13 +388,12 @@ class PushRuleAttributesTestCase(HomeserverTestCase): token = self.login("user", "pass") # enable & check 404 for never-heard-of rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/best.friend/actions", {"actions": ["dont_notify"]}, access_token=token, ) - self.render(request) self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) @@ -437,12 +405,11 @@ class PushRuleAttributesTestCase(HomeserverTestCase): token = self.login("user", "pass") # enable & check 404 for never-heard-of rule - request, channel = self.make_request( + channel = self.make_request( "PUT", "/pushrules/global/override/.m.muahahah/actions", {"actions": ["dont_notify"]}, access_token=token, ) - self.render(request) self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 0d809d25d5..6105eac47c 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -26,12 +26,14 @@ from mock import Mock import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.handlers.pagination import PurgeStatus +from synapse.rest import admin from synapse.rest.client.v1 import directory, login, profile, room from synapse.rest.client.v2_alpha import account from synapse.types import JsonDict, RoomAlias, UserID from synapse.util.stringutils import random_string from tests import unittest +from tests.test_utils import make_awaitable PATH_PREFIX = b"/_matrix/client/api/v1" @@ -44,10 +46,13 @@ class RoomBase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver( - "red", http_client=None, federation_client=Mock(), + "red", federation_http_client=None, federation_client=Mock(), ) - self.hs.get_federation_handler = Mock(return_value=Mock()) + self.hs.get_federation_handler = Mock() + self.hs.get_federation_handler.return_value.maybe_backfill = Mock( + return_value=make_awaitable(None) + ) async def _insert_client_ip(*args, **kwargs): return None @@ -79,19 +84,17 @@ class RoomPermissionsTestCase(RoomBase): self.created_rmid_msg_path = ( "rooms/%s/send/m.room.message/a1" % (self.created_rmid) ).encode("ascii") - request, channel = self.make_request( + channel = self.make_request( "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}' ) - self.render(request) self.assertEquals(200, channel.code, channel.result) # set topic for public room - request, channel = self.make_request( + channel = self.make_request( "PUT", ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"), b'{"topic":"Public Room Topic"}', ) - self.render(request) self.assertEquals(200, channel.code, channel.result) # auth as user_id now @@ -109,37 +112,32 @@ class RoomPermissionsTestCase(RoomBase): ) # send message in uncreated room, expect 403 - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), msg_content, ) - self.render(request) self.assertEquals(403, channel.code, msg=channel.result["body"]) # send message in created room not joined (no state), expect 403 - request, channel = self.make_request("PUT", send_msg_path(), msg_content) - self.render(request) + channel = self.make_request("PUT", send_msg_path(), msg_content) self.assertEquals(403, channel.code, msg=channel.result["body"]) # send message in created room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) - request, channel = self.make_request("PUT", send_msg_path(), msg_content) - self.render(request) + channel = self.make_request("PUT", send_msg_path(), msg_content) self.assertEquals(403, channel.code, msg=channel.result["body"]) # send message in created room and joined, expect 200 self.helper.join(room=self.created_rmid, user=self.user_id) - request, channel = self.make_request("PUT", send_msg_path(), msg_content) - self.render(request) + channel = self.make_request("PUT", send_msg_path(), msg_content) self.assertEquals(200, channel.code, msg=channel.result["body"]) # send message in created room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) - request, channel = self.make_request("PUT", send_msg_path(), msg_content) - self.render(request) + channel = self.make_request("PUT", send_msg_path(), msg_content) self.assertEquals(403, channel.code, msg=channel.result["body"]) def test_topic_perms(self): @@ 
-147,36 +145,30 @@ class RoomPermissionsTestCase(RoomBase): topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid # set/get topic in uncreated room, expect 403 - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content ) - self.render(request) self.assertEquals(403, channel.code, msg=channel.result["body"]) - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid ) - self.render(request) self.assertEquals(403, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room not joined, expect 403 - request, channel = self.make_request("PUT", topic_path, topic_content) - self.render(request) + channel = self.make_request("PUT", topic_path, topic_content) self.assertEquals(403, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("GET", topic_path) - self.render(request) + channel = self.make_request("GET", topic_path) self.assertEquals(403, channel.code, msg=channel.result["body"]) # set topic in created PRIVATE room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) - request, channel = self.make_request("PUT", topic_path, topic_content) - self.render(request) + channel = self.make_request("PUT", topic_path, topic_content) self.assertEquals(403, channel.code, msg=channel.result["body"]) # get topic in created PRIVATE room and invited, expect 403 - request, channel = self.make_request("GET", topic_path) - self.render(request) + channel = self.make_request("GET", topic_path) self.assertEquals(403, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room and joined, expect 200 @@ -184,46 +176,39 @@ class RoomPermissionsTestCase(RoomBase): # Only room ops can set topic by default self.helper.auth_user_id = self.rmcreator_id - request, channel = self.make_request("PUT", topic_path, topic_content) - self.render(request) + channel = self.make_request("PUT", topic_path, topic_content) self.assertEquals(200, channel.code, msg=channel.result["body"]) self.helper.auth_user_id = self.user_id - request, channel = self.make_request("GET", topic_path) - self.render(request) + channel = self.make_request("GET", topic_path) self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body) # set/get topic in created PRIVATE room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) - request, channel = self.make_request("PUT", topic_path, topic_content) - self.render(request) + channel = self.make_request("PUT", topic_path, topic_content) self.assertEquals(403, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("GET", topic_path) - self.render(request) + channel = self.make_request("GET", topic_path) self.assertEquals(200, channel.code, msg=channel.result["body"]) # get topic in PUBLIC room, not joined, expect 403 - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid ) - self.render(request) self.assertEquals(403, channel.code, msg=channel.result["body"]) # set topic in PUBLIC room, not joined, expect 403 - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/state/m.room.topic" % self.created_public_rmid, topic_content, ) - self.render(request) 
self.assertEquals(403, channel.code, msg=channel.result["body"]) def _test_get_membership(self, room=None, members=[], expect_code=None): for member in members: path = "/rooms/%s/state/m.room.member/%s" % (room, member) - request, channel = self.make_request("GET", path) - self.render(request) + channel = self.make_request("GET", path) self.assertEquals(expect_code, channel.code) def test_membership_basic_room_perms(self): @@ -395,19 +380,16 @@ class RoomsMemberListTestCase(RoomBase): def test_get_member_list(self): room_id = self.helper.create_room_as(self.user_id) - request, channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.render(request) + channel = self.make_request("GET", "/rooms/%s/members" % room_id) self.assertEquals(200, channel.code, msg=channel.result["body"]) def test_get_member_list_no_room(self): - request, channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") - self.render(request) + channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") self.assertEquals(403, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission(self): room_id = self.helper.create_room_as("@some_other_guy:red") - request, channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.render(request) + channel = self.make_request("GET", "/rooms/%s/members" % room_id) self.assertEquals(403, channel.code, msg=channel.result["body"]) def test_get_member_list_mixed_memberships(self): @@ -416,20 +398,17 @@ class RoomsMemberListTestCase(RoomBase): room_path = "/rooms/%s/members" % room_id self.helper.invite(room=room_id, src=room_creator, targ=self.user_id) # can't see list if you're just invited. - request, channel = self.make_request("GET", room_path) - self.render(request) + channel = self.make_request("GET", room_path) self.assertEquals(403, channel.code, msg=channel.result["body"]) self.helper.join(room=room_id, user=self.user_id) # can see list now joined - request, channel = self.make_request("GET", room_path) - self.render(request) + channel = self.make_request("GET", room_path) self.assertEquals(200, channel.code, msg=channel.result["body"]) self.helper.leave(room=room_id, user=self.user_id) # can see old list once left - request, channel = self.make_request("GET", room_path) - self.render(request) + channel = self.make_request("GET", room_path) self.assertEquals(200, channel.code, msg=channel.result["body"]) @@ -440,56 +419,45 @@ class RoomsCreateTestCase(RoomBase): def test_post_room_no_keys(self): # POST with no config keys, expect new room id - request, channel = self.make_request("POST", "/createRoom", "{}") + channel = self.make_request("POST", "/createRoom", "{}") - self.render(request) self.assertEquals(200, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) def test_post_room_visibility_key(self): # POST with visibility config key, expect new room id - request, channel = self.make_request( - "POST", "/createRoom", b'{"visibility":"private"}' - ) - self.render(request) + channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}') self.assertEquals(200, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_custom_key(self): # POST with custom config keys, expect new room id - request, channel = self.make_request( - "POST", "/createRoom", b'{"custom":"stuff"}' - ) - self.render(request) + channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}') self.assertEquals(200, channel.code) self.assertTrue("room_id" in 
channel.json_body) def test_post_room_known_and_unknown_keys(self): # POST with custom + known config keys, expect new room id - request, channel = self.make_request( + channel = self.make_request( "POST", "/createRoom", b'{"visibility":"private","custom":"things"}' ) - self.render(request) self.assertEquals(200, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_invalid_content(self): # POST with invalid content / paths, expect 400 - request, channel = self.make_request("POST", "/createRoom", b'{"visibili') - self.render(request) + channel = self.make_request("POST", "/createRoom", b'{"visibili') self.assertEquals(400, channel.code) - request, channel = self.make_request("POST", "/createRoom", b'["hello"]') - self.render(request) + channel = self.make_request("POST", "/createRoom", b'["hello"]') self.assertEquals(400, channel.code) def test_post_room_invitees_invalid_mxid(self): # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088 # Note the trailing space in the MXID here! - request, channel = self.make_request( + channel = self.make_request( "POST", "/createRoom", b'{"invite":["@alice:example.com "]}' ) - self.render(request) self.assertEquals(400, channel.code) @@ -505,66 +473,54 @@ class RoomTopicTestCase(RoomBase): def test_invalid_puts(self): # missing keys or invalid json - request, channel = self.make_request("PUT", self.path, "{}") - self.render(request) + channel = self.make_request("PUT", self.path, "{}") self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", self.path, '{"_name":"bo"}') - self.render(request) + channel = self.make_request("PUT", self.path, '{"_name":"bo"}') self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", self.path, '{"nao') - self.render(request) + channel = self.make_request("PUT", self.path, '{"nao') self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request( + channel = self.make_request( "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]' ) - self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", self.path, "text only") - self.render(request) + channel = self.make_request("PUT", self.path, "text only") self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", self.path, "") - self.render(request) + channel = self.make_request("PUT", self.path, "") self.assertEquals(400, channel.code, msg=channel.result["body"]) # valid key, wrong type content = '{"topic":["Topic name"]}' - request, channel = self.make_request("PUT", self.path, content) - self.render(request) + channel = self.make_request("PUT", self.path, content) self.assertEquals(400, channel.code, msg=channel.result["body"]) def test_rooms_topic(self): # nothing should be there - request, channel = self.make_request("GET", self.path) - self.render(request) + channel = self.make_request("GET", self.path) self.assertEquals(404, channel.code, msg=channel.result["body"]) # valid put content = '{"topic":"Topic name"}' - request, channel = self.make_request("PUT", self.path, content) - self.render(request) + channel = self.make_request("PUT", self.path, content) self.assertEquals(200, channel.code, msg=channel.result["body"]) # valid get - request, channel = self.make_request("GET", self.path) - self.render(request) + channel = self.make_request("GET", 
self.path) self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) def test_rooms_topic_with_extra_keys(self): # valid put with extra keys content = '{"topic":"Seasons","subtopic":"Summer"}' - request, channel = self.make_request("PUT", self.path, content) - self.render(request) + channel = self.make_request("PUT", self.path, content) self.assertEquals(200, channel.code, msg=channel.result["body"]) # valid get - request, channel = self.make_request("GET", self.path) - self.render(request) + channel = self.make_request("GET", self.path) self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) @@ -580,30 +536,22 @@ class RoomMemberStateTestCase(RoomBase): def test_invalid_puts(self): path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) # missing keys or invalid json - request, channel = self.make_request("PUT", path, "{}") - self.render(request) + channel = self.make_request("PUT", path, "{}") self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, '{"_name":"bo"}') - self.render(request) + channel = self.make_request("PUT", path, '{"_name":"bo"}') self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, '{"nao') - self.render(request) + channel = self.make_request("PUT", path, '{"nao') self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request( - "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]' - ) - self.render(request) + channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, "text only") - self.render(request) + channel = self.make_request("PUT", path, "text only") self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, "") - self.render(request) + channel = self.make_request("PUT", path, "") self.assertEquals(400, channel.code, msg=channel.result["body"]) # valid keys, wrong types @@ -612,8 +560,7 @@ class RoomMemberStateTestCase(RoomBase): Membership.JOIN, Membership.LEAVE, ) - request, channel = self.make_request("PUT", path, content.encode("ascii")) - self.render(request) + channel = self.make_request("PUT", path, content.encode("ascii")) self.assertEquals(400, channel.code, msg=channel.result["body"]) def test_rooms_members_self(self): @@ -624,12 +571,10 @@ class RoomMemberStateTestCase(RoomBase): # valid join message (NOOP since we made the room) content = '{"membership":"%s"}' % Membership.JOIN - request, channel = self.make_request("PUT", path, content.encode("ascii")) - self.render(request) + channel = self.make_request("PUT", path, content.encode("ascii")) self.assertEquals(200, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("GET", path, None) - self.render(request) + channel = self.make_request("GET", path, None) self.assertEquals(200, channel.code, msg=channel.result["body"]) expected_response = {"membership": Membership.JOIN} @@ -644,12 +589,10 @@ class RoomMemberStateTestCase(RoomBase): # valid invite message content = '{"membership":"%s"}' % Membership.INVITE - request, channel = self.make_request("PUT", path, content) - self.render(request) + channel = self.make_request("PUT", path, content) self.assertEquals(200, channel.code, 
msg=channel.result["body"]) - request, channel = self.make_request("GET", path, None) - self.render(request) + channel = self.make_request("GET", path, None) self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assertEquals(json.loads(content), channel.json_body) @@ -665,12 +608,10 @@ class RoomMemberStateTestCase(RoomBase): Membership.INVITE, "Join us!", ) - request, channel = self.make_request("PUT", path, content) - self.render(request) + channel = self.make_request("PUT", path, content) self.assertEquals(200, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("GET", path, None) - self.render(request) + channel = self.make_request("GET", path, None) self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assertEquals(json.loads(content), channel.json_body) @@ -679,6 +620,7 @@ class RoomJoinRatelimitTestCase(RoomBase): user_id = "@sid1:red" servlets = [ + admin.register_servlets, profile.register_servlets, room.register_servlets, ] @@ -720,8 +662,7 @@ class RoomJoinRatelimitTestCase(RoomBase): # Update the display name for the user. path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id - request, channel = self.make_request("PUT", path, {"displayname": "John Doe"}) - self.render(request) + channel = self.make_request("PUT", path, {"displayname": "John Doe"}) self.assertEquals(channel.code, 200, channel.json_body) # Check that all the rooms have been sent a profile update into. @@ -731,8 +672,7 @@ class RoomJoinRatelimitTestCase(RoomBase): self.user_id, ) - request, channel = self.make_request("GET", path) - self.render(request) + channel = self.make_request("GET", path) self.assertEquals(channel.code, 200) self.assertIn("displayname", channel.json_body) @@ -756,10 +696,23 @@ class RoomJoinRatelimitTestCase(RoomBase): # Make sure we send more requests than the rate-limiting config would allow # if all of these requests ended up joining the user to a room. for i in range(4): - request, channel = self.make_request("POST", path % room_id, {}) - self.render(request) + channel = self.make_request("POST", path % room_id, {}) self.assertEquals(channel.code, 200) + @unittest.override_config( + { + "rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}, + "auto_join_rooms": ["#room:red", "#room2:red", "#room3:red", "#room4:red"], + "autocreate_auto_join_rooms": True, + }, + ) + def test_autojoin_rooms(self): + user_id = self.register_user("testuser", "password") + + # Check that the new user successfully joined the four rooms + rooms = self.get_success(self.hs.get_datastore().get_rooms_for_user(user_id)) + self.assertEqual(len(rooms), 4) + class RoomMessagesTestCase(RoomBase): """ Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. 
""" @@ -772,51 +725,40 @@ class RoomMessagesTestCase(RoomBase): def test_invalid_puts(self): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) # missing keys or invalid json - request, channel = self.make_request("PUT", path, b"{}") - self.render(request) + channel = self.make_request("PUT", path, b"{}") self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, b'{"_name":"bo"}') - self.render(request) + channel = self.make_request("PUT", path, b'{"_name":"bo"}') self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, b'{"nao') - self.render(request) + channel = self.make_request("PUT", path, b'{"nao') self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request( - "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]' - ) - self.render(request) + channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, b"text only") - self.render(request) + channel = self.make_request("PUT", path, b"text only") self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, b"") - self.render(request) + channel = self.make_request("PUT", path, b"") self.assertEquals(400, channel.code, msg=channel.result["body"]) def test_rooms_messages_sent(self): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) content = b'{"body":"test","msgtype":{"type":"a"}}' - request, channel = self.make_request("PUT", path, content) - self.render(request) + channel = self.make_request("PUT", path, content) self.assertEquals(400, channel.code, msg=channel.result["body"]) # custom message types content = b'{"body":"test","msgtype":"test.custom.text"}' - request, channel = self.make_request("PUT", path, content) - self.render(request) + channel = self.make_request("PUT", path, content) self.assertEquals(200, channel.code, msg=channel.result["body"]) # m.text message type path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id)) content = b'{"body":"test2","msgtype":"m.text"}' - request, channel = self.make_request("PUT", path, content) - self.render(request) + channel = self.make_request("PUT", path, content) self.assertEquals(200, channel.code, msg=channel.result["body"]) @@ -830,10 +772,7 @@ class RoomInitialSyncTestCase(RoomBase): self.room_id = self.helper.create_room_as(self.user_id) def test_initial_sync(self): - request, channel = self.make_request( - "GET", "/rooms/%s/initialSync" % self.room_id - ) - self.render(request) + channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id) self.assertEquals(200, channel.code) self.assertEquals(self.room_id, channel.json_body["room_id"]) @@ -874,10 +813,9 @@ class RoomMessageListTestCase(RoomBase): def test_topo_token_is_accepted(self): token = "t1-0_0_0_0_0_0_0_0_0" - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - self.render(request) self.assertEquals(200, channel.code) self.assertTrue("start" in channel.json_body) self.assertEquals(token, channel.json_body["start"]) @@ -886,10 +824,9 @@ class RoomMessageListTestCase(RoomBase): def test_stream_token_is_accepted_for_fwd_pagianation(self): token = "s0_0_0_0_0_0_0_0_0" - request, channel = 
self.make_request( + channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - self.render(request) self.assertEquals(200, channel.code) self.assertTrue("start" in channel.json_body) self.assertEquals(token, channel.json_body["start"]) @@ -920,7 +857,7 @@ class RoomMessageListTestCase(RoomBase): self.helper.send(self.room_id, "message 3") # Check that we get the first and second message when querying /messages. - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" % ( @@ -929,7 +866,6 @@ class RoomMessageListTestCase(RoomBase): json.dumps({"types": [EventTypes.Message]}), ), ) - self.render(request) self.assertEqual(channel.code, 200, channel.json_body) chunk = channel.json_body["chunk"] @@ -949,7 +885,7 @@ class RoomMessageListTestCase(RoomBase): # Check that we only get the second message through /message now that the first # has been purged. - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" % ( @@ -958,7 +894,6 @@ class RoomMessageListTestCase(RoomBase): json.dumps({"types": [EventTypes.Message]}), ), ) - self.render(request) self.assertEqual(channel.code, 200, channel.json_body) chunk = channel.json_body["chunk"] @@ -967,7 +902,7 @@ class RoomMessageListTestCase(RoomBase): # Check that we get no event, but also no error, when querying /messages with # the token that was pointing at the first event, because we don't have it # anymore. - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" % ( @@ -976,7 +911,6 @@ class RoomMessageListTestCase(RoomBase): json.dumps({"types": [EventTypes.Message]}), ), ) - self.render(request) self.assertEqual(channel.code, 200, channel.json_body) chunk = channel.json_body["chunk"] @@ -1027,7 +961,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): self.helper.send(self.room, body="Hi!", tok=self.other_access_token) self.helper.send(self.room, body="There!", tok=self.other_access_token) - request, channel = self.make_request( + channel = self.make_request( "POST", "/search?access_token=%s" % (self.access_token,), { @@ -1036,7 +970,6 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): } }, ) - self.render(request) # Check we get the results we expect -- one search result, of the sent # messages @@ -1057,7 +990,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): self.helper.send(self.room, body="Hi!", tok=self.other_access_token) self.helper.send(self.room, body="There!", tok=self.other_access_token) - request, channel = self.make_request( + channel = self.make_request( "POST", "/search?access_token=%s" % (self.access_token,), { @@ -1070,7 +1003,6 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): } }, ) - self.render(request) # Check we get the results we expect -- one search result, of the sent # messages @@ -1106,16 +1038,14 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): return self.hs def test_restricted_no_auth(self): - request, channel = self.make_request("GET", self.url) - self.render(request) + channel = self.make_request("GET", self.url) self.assertEqual(channel.code, 401, channel.result) def test_restricted_auth(self): self.register_user("user", "pass") tok = self.login("user", "pass") - request, channel = self.make_request("GET", self.url, access_token=tok) - self.render(request) + channel = 
self.make_request("GET", self.url, access_token=tok) self.assertEqual(channel.code, 200, channel.result) @@ -1143,13 +1073,12 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): self.displayname = "test user" data = {"displayname": self.displayname} request_data = json.dumps(data) - request, channel = self.make_request( + channel = self.make_request( "PUT", "/_matrix/client/r0/profile/%s/displayname" % (self.user_id,), request_data, access_token=self.tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) @@ -1157,23 +1086,21 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): def test_per_room_profile_forbidden(self): data = {"membership": "join", "displayname": "other test user"} request_data = json.dumps(data) - request, channel = self.make_request( + channel = self.make_request( "PUT", "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id), request_data, access_token=self.tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) event_id = channel.json_body["event_id"] - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), access_token=self.tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) res_displayname = channel.json_body["content"]["displayname"] @@ -1202,13 +1129,12 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): def test_join_reason(self): reason = "hello" - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/rooms/{}/join".format(self.room_id), content={"reason": reason}, access_token=self.second_tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) self._check_for_reason(reason) @@ -1217,13 +1143,12 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) reason = "hello" - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/rooms/{}/leave".format(self.room_id), content={"reason": reason}, access_token=self.second_tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) self._check_for_reason(reason) @@ -1232,13 +1157,12 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) reason = "hello" - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/rooms/{}/kick".format(self.room_id), content={"reason": reason, "user_id": self.second_user_id}, access_token=self.second_tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) self._check_for_reason(reason) @@ -1247,39 +1171,36 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) reason = "hello" - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/rooms/{}/ban".format(self.room_id), content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) self._check_for_reason(reason) def test_unban_reason(self): reason = "hello" - request, channel = self.make_request( + channel = 
self.make_request( "POST", "/_matrix/client/r0/rooms/{}/unban".format(self.room_id), content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) self._check_for_reason(reason) def test_invite_reason(self): reason = "hello" - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/rooms/{}/invite".format(self.room_id), content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) self._check_for_reason(reason) @@ -1293,26 +1214,24 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): ) reason = "hello" - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/rooms/{}/leave".format(self.room_id), content={"reason": reason}, access_token=self.second_tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) self._check_for_reason(reason) def _check_for_reason(self, reason): - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format( self.room_id, self.second_user_id ), access_token=self.creator_tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) event_content = channel.json_body @@ -1355,13 +1274,12 @@ class LabelsTestCase(unittest.HomeserverTestCase): """Test that we can filter by a label on a /context request.""" event_id = self._send_labelled_messages_in_room() - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/context/%s?filter=%s" % (self.room_id, event_id, json.dumps(self.FILTER_LABELS)), access_token=self.tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) events_before = channel.json_body["events_before"] @@ -1386,13 +1304,12 @@ class LabelsTestCase(unittest.HomeserverTestCase): """Test that we can filter by the absence of a label on a /context request.""" event_id = self._send_labelled_messages_in_room() - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/context/%s?filter=%s" % (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)), access_token=self.tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) events_before = channel.json_body["events_before"] @@ -1422,13 +1339,12 @@ class LabelsTestCase(unittest.HomeserverTestCase): """ event_id = self._send_labelled_messages_in_room() - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/context/%s?filter=%s" % (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)), access_token=self.tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) events_before = channel.json_body["events_before"] @@ -1451,12 +1367,11 @@ class LabelsTestCase(unittest.HomeserverTestCase): self._send_labelled_messages_in_room() token = "s0_0_0_0_0_0_0_0_0" - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" % (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS)), ) - self.render(request) events = channel.json_body["chunk"] @@ -1469,12 +1384,11 @@ class LabelsTestCase(unittest.HomeserverTestCase): self._send_labelled_messages_in_room() token = "s0_0_0_0_0_0_0_0_0" - request, channel = self.make_request( + channel = self.make_request( 
"GET", "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" % (self.room_id, self.tok, token, json.dumps(self.FILTER_NOT_LABELS)), ) - self.render(request) events = channel.json_body["chunk"] @@ -1493,7 +1407,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): self._send_labelled_messages_in_room() token = "s0_0_0_0_0_0_0_0_0" - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" % ( @@ -1503,7 +1417,6 @@ class LabelsTestCase(unittest.HomeserverTestCase): json.dumps(self.FILTER_LABELS_NOT_LABELS), ), ) - self.render(request) events = channel.json_body["chunk"] @@ -1525,10 +1438,9 @@ class LabelsTestCase(unittest.HomeserverTestCase): self._send_labelled_messages_in_room() - request, channel = self.make_request( + channel = self.make_request( "POST", "/search?access_token=%s" % self.tok, request_data ) - self.render(request) results = channel.json_body["search_categories"]["room_events"]["results"] @@ -1561,10 +1473,9 @@ class LabelsTestCase(unittest.HomeserverTestCase): self._send_labelled_messages_in_room() - request, channel = self.make_request( + channel = self.make_request( "POST", "/search?access_token=%s" % self.tok, request_data ) - self.render(request) results = channel.json_body["search_categories"]["room_events"]["results"] @@ -1609,10 +1520,9 @@ class LabelsTestCase(unittest.HomeserverTestCase): self._send_labelled_messages_in_room() - request, channel = self.make_request( + channel = self.make_request( "POST", "/search?access_token=%s" % self.tok, request_data ) - self.render(request) results = channel.json_body["search_categories"]["room_events"]["results"] @@ -1731,13 +1641,12 @@ class ContextTestCase(unittest.HomeserverTestCase): # Check that we can still see the messages before the erasure request. - request, channel = self.make_request( + channel = self.make_request( "GET", '/rooms/%s/context/%s?filter={"types":["m.room.message"]}' % (self.room_id, event_id), access_token=self.tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) events_before = channel.json_body["events_before"] @@ -1796,13 +1705,12 @@ class ContextTestCase(unittest.HomeserverTestCase): # Check that a user that joined the room after the erasure request can't see # the messages anymore. - request, channel = self.make_request( + channel = self.make_request( "GET", '/rooms/%s/context/%s?filter={"types":["m.room.message"]}' % (self.room_id, event_id), access_token=invited_tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) events_before = channel.json_body["events_before"] @@ -1887,13 +1795,12 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): def _get_aliases(self, access_token: str, expected_code: int = 200) -> JsonDict: """Calls the endpoint under test. 
returns the json response object.""" - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/org.matrix.msc2432/rooms/%s/aliases" % (self.room_id,), access_token=access_token, ) - self.render(request) self.assertEqual(channel.code, expected_code, channel.result) res = channel.json_body self.assertIsInstance(res, dict) @@ -1909,10 +1816,9 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): data = {"room_id": self.room_id} request_data = json.dumps(data) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, request_data, access_token=self.room_owner_tok ) - self.render(request) self.assertEqual(channel.code, expected_code, channel.result) @@ -1940,20 +1846,18 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): data = {"room_id": self.room_id} request_data = json.dumps(data) - request, channel = self.make_request( + channel = self.make_request( "PUT", url, request_data, access_token=self.room_owner_tok ) - self.render(request) self.assertEqual(channel.code, expected_code, channel.result) def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict: """Calls the endpoint under test. returns the json response object.""" - request, channel = self.make_request( + channel = self.make_request( "GET", "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), access_token=self.room_owner_tok, ) - self.render(request) self.assertEqual(channel.code, expected_code, channel.result) res = channel.json_body self.assertIsInstance(res, dict) @@ -1961,13 +1865,12 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict: """Calls the endpoint under test. returns the json response object.""" - request, channel = self.make_request( + channel = self.make_request( "PUT", "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), json.dumps(content), access_token=self.room_owner_tok, ) - self.render(request) self.assertEqual(channel.code, expected_code, channel.result) res = channel.json_body self.assertIsInstance(res, dict) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 94d2bf2eb1..38c51525a3 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -39,12 +39,12 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - "red", http_client=None, federation_client=Mock(), + "red", federation_http_client=None, federation_client=Mock(), ) self.event_source = hs.get_event_sources().sources["typing"] - hs.get_handlers().federation_handler = Mock() + hs.get_federation_handler = Mock() async def get_user_by_access_token(token=None, allow_guest=False): return { @@ -94,12 +94,11 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): self.helper.join(self.room_id, user="@jim:red") def test_set_typing(self): - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": true, "timeout": 30000}', ) - self.render(request) self.assertEquals(200, channel.code) self.assertEquals(self.event_source.get_current_key(), 1) @@ -118,21 +117,19 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): ) def test_set_not_typing(self): - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": false}', ) - self.render(request) self.assertEquals(200, channel.code) def test_typing_timeout(self): - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": true, "timeout": 30000}', ) - self.render(request) self.assertEquals(200, channel.code) self.assertEquals(self.event_source.get_current_key(), 1) @@ -141,12 +138,11 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): self.assertEquals(self.event_source.get_current_key(), 2) - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), b'{"typing": true, "timeout": 30000}', ) - self.render(request) self.assertEquals(200, channel.code) self.assertEquals(self.event_source.get_current_key(), 3) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index afaf9f7b85..dbc27893b5 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -2,7 +2,7 @@ # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd # Copyright 2018-2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019-2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,16 +17,23 @@ # limitations under the License. import json +import re import time +import urllib.parse from typing import Any, Dict, Optional +from mock import patch + import attr from twisted.web.resource import Resource +from twisted.web.server import Site from synapse.api.constants import Membership +from synapse.types import JsonDict -from tests.server import make_request, render +from tests.server import FakeSite, make_request +from tests.test_utils import FakeResponse @attr.s @@ -36,25 +43,51 @@ class RestHelper: """ hs = attr.ib() - resource = attr.ib() + site = attr.ib(type=Site) auth_user_id = attr.ib() def create_room_as( - self, room_creator=None, is_public=True, tok=None, expect_code=200, - ): + self, + room_creator: str = None, + is_public: bool = True, + room_version: str = None, + tok: str = None, + expect_code: int = 200, + ) -> str: + """ + Create a room. + + Args: + room_creator: The user ID to create the room with. + is_public: If True, the `visibility` parameter will be set to the + default (public). Otherwise, the `visibility` parameter will be set + to "private". + room_version: The room version to create the room as. Defaults to Synapse's + default room version. + tok: The access token to use in the request. + expect_code: The expected HTTP response code. + + Returns: + The ID of the newly created room. + """ temp_id = self.auth_user_id self.auth_user_id = room_creator path = "/_matrix/client/r0/createRoom" content = {} if not is_public: content["visibility"] = "private" + if room_version: + content["room_version"] = room_version if tok: path = path + "?access_token=%s" % tok - request, channel = make_request( - self.hs.get_reactor(), "POST", path, json.dumps(content).encode("utf8") + channel = make_request( + self.hs.get_reactor(), + self.site, + "POST", + path, + json.dumps(content).encode("utf8"), ) - render(request, self.resource, self.hs.get_reactor()) assert channel.result["code"] == b"%d" % expect_code, channel.result self.auth_user_id = temp_id @@ -124,12 +157,14 @@ class RestHelper: data = {"membership": membership} data.update(extra_data) - request, channel = make_request( - self.hs.get_reactor(), "PUT", path, json.dumps(data).encode("utf8") + channel = make_request( + self.hs.get_reactor(), + self.site, + "PUT", + path, + json.dumps(data).encode("utf8"), ) - render(request, self.resource, self.hs.get_reactor()) - assert int(channel.result["code"]) == expect_code, ( "Expected: %d, got: %d, resp: %r" % (expect_code, int(channel.result["code"]), channel.result["body"]) @@ -157,10 +192,13 @@ class RestHelper: if tok: path = path + "?access_token=%s" % tok - request, channel = make_request( - self.hs.get_reactor(), "PUT", path, json.dumps(content).encode("utf8") + channel = make_request( + self.hs.get_reactor(), + self.site, + "PUT", + path, + json.dumps(content).encode("utf8"), ) - render(request, self.resource, self.hs.get_reactor()) assert int(channel.result["code"]) == expect_code, ( "Expected: %d, got: %d, resp: %r" @@ -210,9 +248,7 @@ class RestHelper: if body is not None: content = json.dumps(body).encode("utf8") - request, channel = make_request(self.hs.get_reactor(), method, path, 
content) - - render(request, self.resource, self.hs.get_reactor()) + channel = make_request(self.hs.get_reactor(), self.site, method, path, content) assert int(channel.result["code"]) == expect_code, ( "Expected: %d, got: %d, resp: %r" @@ -295,14 +331,15 @@ class RestHelper: """ image_length = len(image_data) path = "/_matrix/media/r0/upload?filename=%s" % (filename,) - request, channel = make_request( - self.hs.get_reactor(), "POST", path, content=image_data, access_token=tok - ) - request.requestHeaders.addRawHeader( - b"Content-Length", str(image_length).encode("UTF-8") + channel = make_request( + self.hs.get_reactor(), + FakeSite(resource), + "POST", + path, + content=image_data, + access_token=tok, + custom_headers=[(b"Content-Length", str(image_length))], ) - request.render(resource) - self.hs.get_reactor().pump([100]) assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( expect_code, @@ -311,3 +348,111 @@ class RestHelper: ) return channel.json_body + + def login_via_oidc(self, remote_user_id: str) -> JsonDict: + """Log in (as a new user) via OIDC + + Returns the result of the final token login. + + Requires that "oidc_config" in the homeserver config be set appropriately + (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a + "public_base_url". + + Also requires the login servlet and the OIDC callback resource to be mounted at + the normal places. + """ + client_redirect_url = "https://x" + + # first hit the redirect url (which will issue a cookie and state) + channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + "/login/sso/redirect?redirectUrl=" + client_redirect_url, + ) + # that will redirect to the OIDC IdP, but we skip that and go straight + # back to synapse's OIDC callback resource. However, we do need the "state" + # param that synapse passes to the IdP via query params, and the cookie that + # synapse passes to the client. + assert channel.code == 302 + oauth_uri = channel.headers.getRawHeaders("Location")[0] + params = urllib.parse.parse_qs(urllib.parse.urlparse(oauth_uri).query) + redirect_uri = "%s?%s" % ( + urllib.parse.urlparse(params["redirect_uri"][0]).path, + urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}), + ) + cookies = {} + for h in channel.headers.getRawHeaders("Set-Cookie"): + parts = h.split(";") + k, v = parts[0].split("=", maxsplit=1) + cookies[k] = v + + # before we hit the callback uri, stub out some methods in the http client so + # that we don't have to handle full HTTPS requests. + + # (expected url, json response) pairs, in the order we expect them. + expected_requests = [ + # first we get a hit to the token endpoint, which we tell to return + # a dummy OIDC access token + ("https://issuer.test/token", {"access_token": "TEST"}), + # and then one to the user_info endpoint, which returns our remote user id. 
+ ("https://issuer.test/userinfo", {"sub": remote_user_id}), + ] + + async def mock_req(method: str, uri: str, data=None, headers=None): + (expected_uri, resp_obj) = expected_requests.pop(0) + assert uri == expected_uri + resp = FakeResponse( + code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"), + ) + return resp + + with patch.object(self.hs.get_proxied_http_client(), "request", mock_req): + # now hit the callback URI with the right params and a made-up code + channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + redirect_uri, + custom_headers=[ + ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items() + ], + ) + + # expect a confirmation page + assert channel.code == 200 + + # fish the matrix login token out of the body of the confirmation page + m = re.search( + 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,), + channel.result["body"].decode("utf-8"), + ) + assert m + login_token = m.group(1) + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token and device id. + channel = make_request( + self.hs.get_reactor(), + self.site, + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + assert channel.code == 200 + return channel.json_body + + +# an 'oidc_config' suitable for login_via_oidc. +TEST_OIDC_CONFIG = { + "enabled": True, + "discover": False, + "issuer": "https://issuer.test", + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "scopes": ["profile"], + "authorization_endpoint": "https://z", + "token_endpoint": "https://issuer.test/token", + "userinfo_endpoint": "https://issuer.test/userinfo", + "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, +} diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index ae2cd67f35..cb87b80e33 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -19,7 +19,6 @@ import os import re from email.parser import Parser from typing import Optional -from urllib.parse import urlencode import pkg_resources @@ -31,6 +30,7 @@ from synapse.rest.client.v2_alpha import account, register from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource from tests import unittest +from tests.server import FakeSite, make_request from tests.unittest import override_config @@ -240,12 +240,11 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.assertIsNotNone(session_id) def _request_token(self, email, client_secret): - request, channel = self.make_request( + channel = self.make_request( "POST", b"account/password/email/requestToken", {"client_secret": client_secret, "email": email, "send_attempt": 1}, ) - self.render(request) self.assertEquals(200, channel.code, channel.result) return channel.json_body["sid"] @@ -255,31 +254,29 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): path = link.replace("https://example.com", "") # Load the password reset confirmation page - request, channel = self.make_request("GET", path, shorthand=False) - request.render(self.submit_token_resource) - self.pump() + channel = make_request( + self.reactor, + FakeSite(self.submit_token_resource), + "GET", + path, + shorthand=False, + ) + self.assertEquals(200, channel.code, channel.result) # Now POST to the same endpoint, mimicking the same behaviour as clicking the # password reset confirm button - # Send arguments as url-encoded form data, matching the template's behaviour - form_args = [] - for key, value_list in request.args.items(): - for value in value_list: - arg = (key, value) - form_args.append(arg) - # Confirm the password reset - request, channel = self.make_request( + channel = make_request( + self.reactor, + FakeSite(self.submit_token_resource), "POST", path, - content=urlencode(form_args).encode("utf8"), + content=b"", shorthand=False, content_is_form=True, ) - request.render(self.submit_token_resource) - self.pump() self.assertEquals(200, channel.code, channel.result) def _get_link_from_email(self): @@ -305,7 +302,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): def _reset_password( self, new_password, session_id, client_secret, expected_code=200 ): - request, channel = self.make_request( + channel = self.make_request( "POST", b"account/password", { @@ -319,7 +316,6 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): }, }, ) - self.render(request) self.assertEquals(expected_code, channel.code, channel.result) @@ -348,11 +344,9 @@ class DeactivateTestCase(unittest.HomeserverTestCase): self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id))) # Check that this access token has been invalidated. 
- request, channel = self.make_request("GET", "account/whoami") - self.render(request) - self.assertEqual(request.code, 401) + channel = self.make_request("GET", "account/whoami") + self.assertEqual(channel.code, 401) - @unittest.INFO def test_pending_invites(self): """Tests that deactivating a user rejects every pending invite for them.""" store = self.hs.get_datastore() @@ -405,11 +399,10 @@ class DeactivateTestCase(unittest.HomeserverTestCase): "erase": False, } ) - request, channel = self.make_request( + channel = self.make_request( "POST", "account/deactivate", request_data, access_token=tok ) - self.render(request) - self.assertEqual(request.code, 200) + self.assertEqual(channel.code, 200) class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): @@ -529,7 +522,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self._validate_token(link) - request, channel = self.make_request( + channel = self.make_request( "POST", b"/_matrix/client/unstable/account/3pid/add", { @@ -543,15 +536,13 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user - request, channel = self.make_request( + channel = self.make_request( "GET", self.url_3pid, access_token=self.user_id_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) @@ -570,20 +561,18 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) ) - request, channel = self.make_request( + channel = self.make_request( "POST", b"account/3pid/delete", {"medium": "email", "address": self.email}, access_token=self.user_id_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Get user - request, channel = self.make_request( + channel = self.make_request( "GET", self.url_3pid, access_token=self.user_id_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) @@ -604,22 +593,20 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) ) - request, channel = self.make_request( + channel = self.make_request( "POST", b"account/3pid/delete", {"medium": "email", "address": self.email}, access_token=self.user_id_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user - request, channel = self.make_request( + channel = self.make_request( "GET", self.url_3pid, access_token=self.user_id_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -634,7 +621,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.assertEquals(len(self.email_attempts), 1) # Attempt to add email without clicking the link - request, channel = self.make_request( + channel = self.make_request( "POST", b"/_matrix/client/unstable/account/3pid/add", { @@ -648,15 +635,13 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.THREEPID_AUTH_FAILED, 
channel.json_body["errcode"]) # Get user - request, channel = self.make_request( + channel = self.make_request( "GET", self.url_3pid, access_token=self.user_id_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) @@ -669,7 +654,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): session_id = "weasle" # Attempt to add email without even requesting an email - request, channel = self.make_request( + channel = self.make_request( "POST", b"/_matrix/client/unstable/account/3pid/add", { @@ -683,15 +668,13 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user - request, channel = self.make_request( + channel = self.make_request( "GET", self.url_3pid, access_token=self.user_id_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) @@ -793,10 +776,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): if next_link: body["next_link"] = next_link - request, channel = self.make_request( - "POST", b"account/3pid/email/requestToken", body, - ) - self.render(request) + channel = self.make_request("POST", b"account/3pid/email/requestToken", body,) self.assertEquals(expect_code, channel.code, channel.result) return channel.json_body.get("sid") @@ -804,12 +784,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def _request_token_invalid_email( self, email, expected_errcode, expected_error, client_secret="foobar", ): - request, channel = self.make_request( + channel = self.make_request( "POST", b"account/3pid/email/requestToken", {"client_secret": client_secret, "email": email, "send_attempt": 1}, ) - self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(expected_errcode, channel.json_body["errcode"]) self.assertEqual(expected_error, channel.json_body["error"]) @@ -818,8 +797,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): # Remove the host path = link.replace("https://example.com", "") - request, channel = self.make_request("GET", path, shorthand=False) - self.render(request) + channel = self.make_request("GET", path, shorthand=False) self.assertEquals(200, channel.code, channel.result) def _get_link_from_email(self): @@ -853,7 +831,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self._validate_token(link) - request, channel = self.make_request( + channel = self.make_request( "POST", b"/_matrix/client/unstable/account/3pid/add", { @@ -868,14 +846,12 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Get user - request, channel = self.make_request( + channel = self.make_request( "GET", self.url_3pid, access_token=self.user_id_tok, ) - self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 293ccfba2b..ac66a4e0b7 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
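The hunks below for test_auth.py go beyond the mechanical make_request change: UIAuthTests gains an OIDC-enabled config and a mounted OIDC callback resource so that SSO-based user-interactive auth can be exercised, and delete_device now takes an explicit access token and device ID. A condensed sketch of the overridden hooks, following the imports these hunks themselves add (TEST_OIDC_CONFIG and OIDCResource); the rest of the HomeserverTestCase machinery is assumed unchanged:

from synapse.rest.oidc import OIDCResource

from tests import unittest
from tests.rest.client.v1.utils import TEST_OIDC_CONFIG


class UIAuthTests(unittest.HomeserverTestCase):
    def default_config(self):
        config = super().default_config()
        # Enable OIDC so SSO can appear as a UI auth stage for SSO-created users.
        oidc_config = dict(TEST_OIDC_CONFIG)
        oidc_config["allow_existing_users"] = True
        config["oidc_config"] = oidc_config
        config["public_baseurl"] = "https://synapse.test"
        return config

    def create_resource_dict(self):
        resource_dict = super().create_resource_dict()
        # Mount the OIDC callback resource where the login helper expects it.
        resource_dict["/_synapse/oidc"] = OIDCResource(self.hs)
        return resource_dict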
@@ -12,19 +12,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union + +from typing import Union from twisted.internet.defer import succeed import synapse.rest.admin from synapse.api.constants import LoginType from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker -from synapse.http.site import SynapseRequest from synapse.rest.client.v1 import login from synapse.rest.client.v2_alpha import auth, devices, register -from synapse.types import JsonDict +from synapse.rest.oidc import OIDCResource +from synapse.types import JsonDict, UserID from tests import unittest +from tests.rest.client.v1.utils import TEST_OIDC_CONFIG from tests.server import FakeChannel @@ -38,11 +40,6 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker): return succeed(True) -class DummyPasswordChecker(UserInteractiveAuthChecker): - def check_auth(self, authdict, clientip): - return succeed(authdict["identifier"]["user"]) - - class FallbackAuthTests(unittest.HomeserverTestCase): servlets = [ @@ -69,12 +66,9 @@ class FallbackAuthTests(unittest.HomeserverTestCase): def register(self, expected_response: int, body: JsonDict) -> FakeChannel: """Make a register request.""" - request, channel = self.make_request( - "POST", "register", body - ) # type: SynapseRequest, FakeChannel - self.render(request) + channel = self.make_request("POST", "register", body) - self.assertEqual(request.code, expected_response) + self.assertEqual(channel.code, expected_response) return channel def recaptcha( @@ -84,27 +78,24 @@ class FallbackAuthTests(unittest.HomeserverTestCase): if post_session is None: post_session = session - request, channel = self.make_request( + channel = self.make_request( "GET", "auth/m.login.recaptcha/fallback/web?session=" + session - ) # type: SynapseRequest, FakeChannel - self.render(request) - self.assertEqual(request.code, 200) + ) + self.assertEqual(channel.code, 200) - request, channel = self.make_request( + channel = self.make_request( "POST", "auth/m.login.recaptcha/fallback/web?session=" + post_session + "&g-recaptcha-response=a", ) - self.render(request) - self.assertEqual(request.code, expected_post_response) + self.assertEqual(channel.code, expected_post_response) # The recaptcha handler is called with the response given attempts = self.recaptcha_checker.recaptcha_attempts self.assertEqual(len(attempts), 1) self.assertEqual(attempts[0][0]["response"], "a") - @unittest.INFO def test_fallback_captcha(self): """Ensure that fallback auth via a captcha works.""" # Returns a 401 as per the spec @@ -165,36 +156,44 @@ class UIAuthTests(unittest.HomeserverTestCase): register.register_servlets, ] - def prepare(self, reactor, clock, hs): - auth_handler = hs.get_auth_handler() - auth_handler.checkers[LoginType.PASSWORD] = DummyPasswordChecker(hs) + def default_config(self): + config = super().default_config() - self.user_pass = "pass" - self.user = self.register_user("test", self.user_pass) - self.user_tok = self.login("test", self.user_pass) + # we enable OIDC as a way of testing SSO flows + oidc_config = {} + oidc_config.update(TEST_OIDC_CONFIG) + oidc_config["allow_existing_users"] = True - def get_device_ids(self) -> List[str]: - # Get the list of devices so one can be deleted. 
- request, channel = self.make_request( - "GET", "devices", access_token=self.user_tok, - ) # type: SynapseRequest, FakeChannel - self.render(request) + config["oidc_config"] = oidc_config + config["public_baseurl"] = "https://synapse.test" + return config - # Get the ID of the device. - self.assertEqual(request.code, 200) - return [d["device_id"] for d in channel.json_body["devices"]] + def create_resource_dict(self): + resource_dict = super().create_resource_dict() + # mount the OIDC resource at /_synapse/oidc + resource_dict["/_synapse/oidc"] = OIDCResource(self.hs) + return resource_dict + + def prepare(self, reactor, clock, hs): + self.user_pass = "pass" + self.user = self.register_user("test", self.user_pass) + self.device_id = "dev1" + self.user_tok = self.login("test", self.user_pass, self.device_id) def delete_device( - self, device: str, expected_response: int, body: Union[bytes, JsonDict] = b"" + self, + access_token: str, + device: str, + expected_response: int, + body: Union[bytes, JsonDict] = b"", ) -> FakeChannel: """Delete an individual device.""" - request, channel = self.make_request( - "DELETE", "devices/" + device, body, access_token=self.user_tok - ) # type: SynapseRequest, FakeChannel - self.render(request) + channel = self.make_request( + "DELETE", "devices/" + device, body, access_token=access_token, + ) # Ensure the response is sane. - self.assertEqual(request.code, expected_response) + self.assertEqual(channel.code, expected_response) return channel @@ -202,13 +201,12 @@ class UIAuthTests(unittest.HomeserverTestCase): """Delete 1 or more devices.""" # Note that this uses the delete_devices endpoint so that we can modify # the payload half-way through some tests. - request, channel = self.make_request( + channel = self.make_request( "POST", "delete_devices", body, access_token=self.user_tok, - ) # type: SynapseRequest, FakeChannel - self.render(request) + ) # Ensure the response is sane. - self.assertEqual(request.code, expected_response) + self.assertEqual(channel.code, expected_response) return channel @@ -216,11 +214,9 @@ class UIAuthTests(unittest.HomeserverTestCase): """ Test user interactive authentication outside of registration. """ - device_id = self.get_device_ids()[0] - # Attempt to delete this device. # Returns a 401 as per the spec - channel = self.delete_device(device_id, 401) + channel = self.delete_device(self.user_tok, self.device_id, 401) # Grab the session session = channel.json_body["session"] @@ -229,7 +225,8 @@ class UIAuthTests(unittest.HomeserverTestCase): # Make another request providing the UI auth flow. self.delete_device( - device_id, + self.user_tok, + self.device_id, 200, { "auth": { @@ -241,6 +238,31 @@ class UIAuthTests(unittest.HomeserverTestCase): }, ) + def test_grandfathered_identifier(self): + """Check behaviour without "identifier" dict + + Synapse used to require clients to submit a "user" field for m.login.password + UIA - check that still works. + """ + + channel = self.delete_device(self.user_tok, self.device_id, 401) + session = channel.json_body["session"] + + # Make another request providing the UI auth flow. + self.delete_device( + self.user_tok, + self.device_id, + 200, + { + "auth": { + "type": "m.login.password", + "user": self.user, + "password": self.user_pass, + "session": session, + }, + }, + ) + def test_can_change_body(self): """ The client dict can be modified during the user interactive authentication session. @@ -252,14 +274,11 @@ class UIAuthTests(unittest.HomeserverTestCase): session ID should be rejected. 
""" # Create a second login. - self.login("test", self.user_pass) - - device_ids = self.get_device_ids() - self.assertEqual(len(device_ids), 2) + self.login("test", self.user_pass, "dev2") # Attempt to delete the first device. # Returns a 401 as per the spec - channel = self.delete_devices(401, {"devices": [device_ids[0]]}) + channel = self.delete_devices(401, {"devices": [self.device_id]}) # Grab the session session = channel.json_body["session"] @@ -271,7 +290,7 @@ class UIAuthTests(unittest.HomeserverTestCase): self.delete_devices( 200, { - "devices": [device_ids[1]], + "devices": ["dev2"], "auth": { "type": "m.login.password", "identifier": {"type": "m.id.user", "user": self.user}, @@ -286,14 +305,11 @@ class UIAuthTests(unittest.HomeserverTestCase): The initial requested URI cannot be modified during the user interactive authentication session. """ # Create a second login. - self.login("test", self.user_pass) - - device_ids = self.get_device_ids() - self.assertEqual(len(device_ids), 2) + self.login("test", self.user_pass, "dev2") # Attempt to delete the first device. # Returns a 401 as per the spec - channel = self.delete_device(device_ids[0], 401) + channel = self.delete_device(self.user_tok, self.device_id, 401) # Grab the session session = channel.json_body["session"] @@ -302,8 +318,11 @@ class UIAuthTests(unittest.HomeserverTestCase): # Make another request providing the UI auth flow, but try to delete the # second device. This results in an error. + # + # This makes use of the fact that the device ID is embedded into the URL. self.delete_device( - device_ids[1], + self.user_tok, + "dev2", 403, { "auth": { @@ -314,3 +333,83 @@ class UIAuthTests(unittest.HomeserverTestCase): }, }, ) + + @unittest.override_config({"ui_auth": {"session_timeout": 5 * 1000}}) + def test_can_reuse_session(self): + """ + The session can be reused if configured. + + Compare to test_cannot_change_uri. + """ + # Create a second and third login. + self.login("test", self.user_pass, "dev2") + self.login("test", self.user_pass, "dev3") + + # Attempt to delete a device. This works since the user just logged in. + self.delete_device(self.user_tok, "dev2", 200) + + # Move the clock forward past the validation timeout. + self.reactor.advance(6) + + # Deleting another devices throws the user into UI auth. + channel = self.delete_device(self.user_tok, "dev3", 401) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow. + self.delete_device( + self.user_tok, + "dev3", + 200, + { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) + + # Make another request, but try to delete the first device. This works + # due to re-using the previous session. + # + # Note that *no auth* information is provided, not even a session iD! + self.delete_device(self.user_tok, self.device_id, 200) + + def test_does_not_offer_password_for_sso_user(self): + login_resp = self.helper.login_via_oidc("username") + user_tok = login_resp["access_token"] + device_id = login_resp["device_id"] + + # now call the device deletion API: we should get the option to auth with SSO + # and not password. 
+ channel = self.delete_device(user_tok, device_id, 401) + + flows = channel.json_body["flows"] + self.assertEqual(flows, [{"stages": ["m.login.sso"]}]) + + def test_does_not_offer_sso_for_password_user(self): + # now call the device deletion API: we should get the option to auth with SSO + # and not password. + channel = self.delete_device(self.user_tok, self.device_id, 401) + + flows = channel.json_body["flows"] + self.assertEqual(flows, [{"stages": ["m.login.password"]}]) + + def test_offers_both_flows_for_upgraded_user(self): + """A user that had a password and then logged in with SSO should get both flows + """ + login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) + self.assertEqual(login_resp["user_id"], self.user) + + channel = self.delete_device(self.user_tok, self.device_id, 401) + + flows = channel.json_body["flows"] + # we have no particular expectations of ordering here + self.assertIn({"stages": ["m.login.password"]}, flows) + self.assertIn({"stages": ["m.login.sso"]}, flows) + self.assertEqual(len(flows), 2) diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py
index b9e01c9418..e808339fb3 100644
--- a/tests/rest/client/v2_alpha/test_capabilities.py
+++ b/tests/rest/client/v2_alpha/test_capabilities.py
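The capability tests below show the mechanical change applied throughout this commit: the test helper now performs the render step itself and returns only the FakeChannel, so the separate self.render(request) call and the (request, channel) unpacking disappear. A before/after sketch of the pattern; self.url and access_token are whatever the surrounding test already defines:

# Old helper shape (pre-refactor), kept only as a comment for contrast:
#     request, channel = self.make_request("GET", self.url, access_token=access_token)
#     self.render(request)

# New helper shape used by the hunks below: one call, one completed FakeChannel back.
channel = self.make_request("GET", self.url, access_token=access_token)
self.assertEqual(channel.code, 200)
capabilities = channel.json_body["capabilities"]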
@@ -36,8 +36,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): return hs def test_check_auth_required(self): - request, channel = self.make_request("GET", self.url) - self.render(request) + channel = self.make_request("GET", self.url) self.assertEqual(channel.code, 401) @@ -45,8 +44,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.register_user("user", "pass") access_token = self.login("user", "pass") - request, channel = self.make_request("GET", self.url, access_token=access_token) - self.render(request) + channel = self.make_request("GET", self.url, access_token=access_token) capabilities = channel.json_body["capabilities"] self.assertEqual(channel.code, 200) @@ -64,8 +62,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): user = self.register_user(localpart, password) access_token = self.login(user, password) - request, channel = self.make_request("GET", self.url, access_token=access_token) - self.render(request) + channel = self.make_request("GET", self.url, access_token=access_token) capabilities = channel.json_body["capabilities"] self.assertEqual(channel.code, 200) @@ -73,8 +70,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): # Test case where password is handled outside of Synapse self.assertTrue(capabilities["m.change_password"]["enabled"]) self.get_success(self.store.user_set_password_hash(user, None)) - request, channel = self.make_request("GET", self.url, access_token=access_token) - self.render(request) + channel = self.make_request("GET", self.url, access_token=access_token) capabilities = channel.json_body["capabilities"] self.assertEqual(channel.code, 200) diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index de00350580..f761c44936 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
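The filter tests keep asserting against channel.result["code"], the raw byte-string status, while other tests in this commit use the parsed integer channel.code; both live on the FakeChannel that make_request now returns. A small illustrative sketch, assuming the FakeChannel behaviour shown elsewhere in this diff:

channel = self.make_request(
    "POST",
    "/_matrix/client/r0/user/%s/filter" % (self.user_id,),
    self.EXAMPLE_FILTER_JSON,
)
# FakeChannel exposes the status both as raw bytes and as a parsed integer.
self.assertEqual(channel.result["code"], b"200")
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {"filter_id": "0"})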
@@ -36,12 +36,11 @@ class FilterTestCase(unittest.HomeserverTestCase): self.store = hs.get_datastore() def test_add_filter(self): - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/user/%s/filter" % (self.user_id), self.EXAMPLE_FILTER_JSON, ) - self.render(request) self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.json_body, {"filter_id": "0"}) @@ -50,12 +49,11 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEquals(filter.result, self.EXAMPLE_FILTER) def test_add_filter_for_other_user(self): - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"), self.EXAMPLE_FILTER_JSON, ) - self.render(request) self.assertEqual(channel.result["code"], b"403") self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) @@ -63,12 +61,11 @@ class FilterTestCase(unittest.HomeserverTestCase): def test_add_filter_non_local_user(self): _is_mine = self.hs.is_mine self.hs.is_mine = lambda target_user: False - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/user/%s/filter" % (self.user_id), self.EXAMPLE_FILTER_JSON, ) - self.render(request) self.hs.is_mine = _is_mine self.assertEqual(channel.result["code"], b"403") @@ -82,19 +79,17 @@ class FilterTestCase(unittest.HomeserverTestCase): ) self.reactor.advance(1) filter_id = filter_id.result - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id) ) - self.render(request) self.assertEqual(channel.result["code"], b"200") self.assertEquals(channel.json_body, self.EXAMPLE_FILTER) def test_get_filter_non_existant(self): - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id) ) - self.render(request) self.assertEqual(channel.result["code"], b"404") self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND) @@ -102,18 +97,16 @@ class FilterTestCase(unittest.HomeserverTestCase): # Currently invalid params do not have an appropriate errcode # in errors.py def test_get_filter_invalid_id(self): - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id) ) - self.render(request) self.assertEqual(channel.result["code"], b"400") # No ID also returns an invalid_id error def test_get_filter_no_id(self): - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id) ) - self.render(request) self.assertEqual(channel.result["code"], b"400") diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py
index c57072f50c..fba34def30 100644
--- a/tests/rest/client/v2_alpha/test_password_policy.py
+++ b/tests/rest/client/v2_alpha/test_password_policy.py
@@ -70,10 +70,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): def test_get_policy(self): """Tests if the /password_policy endpoint returns the configured policy.""" - request, channel = self.make_request( - "GET", "/_matrix/client/r0/password_policy" - ) - self.render(request) + channel = self.make_request("GET", "/_matrix/client/r0/password_policy") self.assertEqual(channel.code, 200, channel.result) self.assertEqual( @@ -90,8 +87,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): def test_password_too_short(self): request_data = json.dumps({"username": "kermit", "password": "shorty"}) - request, channel = self.make_request("POST", self.register_url, request_data) - self.render(request) + channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, 400, channel.result) self.assertEqual( @@ -100,8 +96,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): def test_password_no_digit(self): request_data = json.dumps({"username": "kermit", "password": "longerpassword"}) - request, channel = self.make_request("POST", self.register_url, request_data) - self.render(request) + channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, 400, channel.result) self.assertEqual( @@ -110,8 +105,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): def test_password_no_symbol(self): request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"}) - request, channel = self.make_request("POST", self.register_url, request_data) - self.render(request) + channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, 400, channel.result) self.assertEqual( @@ -120,8 +114,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): def test_password_no_uppercase(self): request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"}) - request, channel = self.make_request("POST", self.register_url, request_data) - self.render(request) + channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, 400, channel.result) self.assertEqual( @@ -130,8 +123,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): def test_password_no_lowercase(self): request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"}) - request, channel = self.make_request("POST", self.register_url, request_data) - self.render(request) + channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, 400, channel.result) self.assertEqual( @@ -140,8 +132,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): def test_password_compliant(self): request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"}) - request, channel = self.make_request("POST", self.register_url, request_data) - self.render(request) + channel = self.make_request("POST", self.register_url, request_data) # Getting a 401 here means the password has passed validation and the server has # responded with a list of registration flows. 
@@ -167,13 +158,12 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): }, } ) - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/account/password", request_data, access_token=tok, ) - self.render(request) self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 2fc3a60fc5..27db4f551e 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
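Besides the helper migration, the register and account-validity hunks below move the admin endpoint from /_matrix/client/unstable/admin/account_validity/validity to /_synapse/admin/v1/account_validity/validity and fetch the clock via hs.get_clock(). A sketch of calling the relocated endpoint, mirroring the request body used in these tests; user_id and admin_tok are set up by the surrounding test case:

import json

url = "/_synapse/admin/v1/account_validity/validity"
request_data = json.dumps(
    {
        "user_id": user_id,
        "expiration_ts": 0,  # expire the account immediately
        "enable_renewal_emails": False,
    }
)
channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
self.assertEquals(channel.result["code"], b"200", channel.result)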
@@ -55,15 +55,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.hs.config.server_name, id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + sender="@as:test", ) self.hs.get_datastore().services_cache.append(appservice) request_data = json.dumps({"username": "as_user_kermit"}) - request, channel = self.make_request( + channel = self.make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) det_data = {"user_id": user_id, "home_server": self.hs.hostname} @@ -72,25 +72,22 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_POST_appservice_registration_invalid(self): self.appservice = None # no application service exists request_data = json.dumps({"username": "kermit"}) - request, channel = self.make_request( + channel = self.make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.render(request) self.assertEquals(channel.result["code"], b"401", channel.result) def test_POST_bad_password(self): request_data = json.dumps({"username": "kermit", "password": 666}) - request, channel = self.make_request(b"POST", self.url, request_data) - self.render(request) + channel = self.make_request(b"POST", self.url, request_data) self.assertEquals(channel.result["code"], b"400", channel.result) self.assertEquals(channel.json_body["error"], "Invalid password") def test_POST_bad_username(self): request_data = json.dumps({"username": 777, "password": "monkey"}) - request, channel = self.make_request(b"POST", self.url, request_data) - self.render(request) + channel = self.make_request(b"POST", self.url, request_data) self.assertEquals(channel.result["code"], b"400", channel.result) self.assertEquals(channel.json_body["error"], "Invalid username") @@ -105,8 +102,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "auth": {"type": LoginType.DUMMY}, } request_data = json.dumps(params) - request, channel = self.make_request(b"POST", self.url, request_data) - self.render(request) + channel = self.make_request(b"POST", self.url, request_data) det_data = { "user_id": user_id, @@ -121,18 +117,17 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): request_data = json.dumps({"username": "kermit", "password": "monkey"}) self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) - request, channel = self.make_request(b"POST", self.url, request_data) - self.render(request) + channel = self.make_request(b"POST", self.url, request_data) self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.json_body["error"], "Registration has been disabled") + self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN") def test_POST_guest_registration(self): self.hs.config.macaroon_secret_key = "test" self.hs.config.allow_guest_access = True - request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.render(request) + channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} self.assertEquals(channel.result["code"], b"200", channel.result) @@ -141,8 +136,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_POST_disabled_guest_registration(self): self.hs.config.allow_guest_access = False - request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - 
self.render(request) + channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.json_body["error"], "Guest access is disabled") @@ -151,8 +145,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_POST_ratelimiting_guest(self): for i in range(0, 6): url = self.url + b"?kind=guest" - request, channel = self.make_request(b"POST", url, b"{}") - self.render(request) + channel = self.make_request(b"POST", url, b"{}") if i == 5: self.assertEquals(channel.result["code"], b"429", channel.result) @@ -162,8 +155,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.reactor.advance(retry_after_ms / 1000.0 + 1.0) - request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.render(request) + channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.assertEquals(channel.result["code"], b"200", channel.result) @@ -177,8 +169,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "auth": {"type": LoginType.DUMMY}, } request_data = json.dumps(params) - request, channel = self.make_request(b"POST", self.url, request_data) - self.render(request) + channel = self.make_request(b"POST", self.url, request_data) if i == 5: self.assertEquals(channel.result["code"], b"429", channel.result) @@ -188,14 +179,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.reactor.advance(retry_after_ms / 1000.0 + 1.0) - request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.render(request) + channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.assertEquals(channel.result["code"], b"200", channel.result) def test_advertised_flows(self): - request, channel = self.make_request(b"POST", self.url, b"{}") - self.render(request) + channel = self.make_request(b"POST", self.url, b"{}") self.assertEquals(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] @@ -218,8 +207,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): } ) def test_advertised_flows_captcha_and_terms_and_3pids(self): - request, channel = self.make_request(b"POST", self.url, b"{}") - self.render(request) + channel = self.make_request(b"POST", self.url, b"{}") self.assertEquals(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] @@ -251,8 +239,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): } ) def test_advertised_flows_no_msisdn_email_required(self): - request, channel = self.make_request(b"POST", self.url, b"{}") - self.render(request) + channel = self.make_request(b"POST", self.url, b"{}") self.assertEquals(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] @@ -292,12 +279,11 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) ) - request, channel = self.make_request( + channel = self.make_request( "POST", b"register/email/requestToken", {"client_secret": "foobar", "email": email, "send_attempt": 1}, ) - self.render(request) self.assertEquals(200, channel.code, channel.result) self.assertIsNotNone(channel.json_body.get("sid")) @@ -332,15 +318,13 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. 
- request, channel = self.make_request(b"GET", "/sync", access_token=tok) - self.render(request) + channel = self.make_request(b"GET", "/sync", access_token=tok) self.assertEquals(channel.result["code"], b"200", channel.result) self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) - request, channel = self.make_request(b"GET", "/sync", access_token=tok) - self.render(request) + channel = self.make_request(b"GET", "/sync", access_token=tok) self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals( @@ -359,19 +343,15 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): self.register_user("admin", "adminpassword", admin=True) admin_tok = self.login("admin", "adminpassword") - url = "/_matrix/client/unstable/admin/account_validity/validity" + url = "/_synapse/admin/v1/account_validity/validity" params = {"user_id": user_id} request_data = json.dumps(params) - request, channel = self.make_request( - b"POST", url, request_data, access_token=admin_tok - ) - self.render(request) + channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) self.assertEquals(channel.result["code"], b"200", channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. - request, channel = self.make_request(b"GET", "/sync", access_token=tok) - self.render(request) + channel = self.make_request(b"GET", "/sync", access_token=tok) self.assertEquals(channel.result["code"], b"200", channel.result) def test_manual_expire(self): @@ -381,23 +361,19 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): self.register_user("admin", "adminpassword", admin=True) admin_tok = self.login("admin", "adminpassword") - url = "/_matrix/client/unstable/admin/account_validity/validity" + url = "/_synapse/admin/v1/account_validity/validity" params = { "user_id": user_id, "expiration_ts": 0, "enable_renewal_emails": False, } request_data = json.dumps(params) - request, channel = self.make_request( - b"POST", url, request_data, access_token=admin_tok - ) - self.render(request) + channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) self.assertEquals(channel.result["code"], b"200", channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. 
- request, channel = self.make_request(b"GET", "/sync", access_token=tok) - self.render(request) + channel = self.make_request(b"GET", "/sync", access_token=tok) self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals( channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result @@ -410,30 +386,25 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): self.register_user("admin", "adminpassword", admin=True) admin_tok = self.login("admin", "adminpassword") - url = "/_matrix/client/unstable/admin/account_validity/validity" + url = "/_synapse/admin/v1/account_validity/validity" params = { "user_id": user_id, "expiration_ts": 0, "enable_renewal_emails": False, } request_data = json.dumps(params) - request, channel = self.make_request( - b"POST", url, request_data, access_token=admin_tok - ) - self.render(request) + channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) self.assertEquals(channel.result["code"], b"200", channel.result) # Try to log the user out - request, channel = self.make_request(b"POST", "/logout", access_token=tok) - self.render(request) + channel = self.make_request(b"POST", "/logout", access_token=tok) self.assertEquals(channel.result["code"], b"200", channel.result) # Log the user in again (allowed for expired accounts) tok = self.login("kermit", "monkey") # Try to log out all of the user's sessions - request, channel = self.make_request(b"POST", "/logout/all", access_token=tok) - self.render(request) + channel = self.make_request(b"POST", "/logout/all", access_token=tok) self.assertEquals(channel.result["code"], b"200", channel.result) @@ -507,8 +478,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # retrieve the token from the DB. renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id)) url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token - request, channel = self.make_request(b"GET", url) - self.render(request) + channel = self.make_request(b"GET", url) self.assertEquals(channel.result["code"], b"200", channel.result) # Check that we're getting HTML back. @@ -528,16 +498,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # our access token should be denied from now, otherwise they should # succeed. self.reactor.advance(datetime.timedelta(days=3).total_seconds()) - request, channel = self.make_request(b"GET", "/sync", access_token=tok) - self.render(request) + channel = self.make_request(b"GET", "/sync", access_token=tok) self.assertEquals(channel.result["code"], b"200", channel.result) def test_renewal_invalid_token(self): # Hit the renewal endpoint with an invalid token and check that it behaves as # expected, i.e. that it responds with 404 Not Found and the correct HTML. url = "/_matrix/client/unstable/account_validity/renew?token=123" - request, channel = self.make_request(b"GET", url) - self.render(request) + channel = self.make_request(b"GET", url) self.assertEquals(channel.result["code"], b"404", channel.result) # Check that we're getting HTML back. 
@@ -558,12 +526,11 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.email_attempts = [] (user_id, tok) = self.create_user() - request, channel = self.make_request( + channel = self.make_request( b"POST", "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, ) - self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEqual(len(self.email_attempts), 1) @@ -583,11 +550,10 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): "erase": False, } ) - request, channel = self.make_request( + channel = self.make_request( "POST", "account/deactivate", request_data, access_token=tok ) - self.render(request) - self.assertEqual(request.code, 200) + self.assertEqual(channel.code, 200) self.reactor.advance(datetime.timedelta(days=8).total_seconds()) @@ -598,7 +564,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): tok = self.login("kermit", "monkey") # We need to manually add an email address otherwise the handler will do # nothing. - now = self.hs.clock.time_msec() + now = self.hs.get_clock().time_msec() self.get_success( self.store.user_add_threepid( user_id=user_id, @@ -616,7 +582,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # We need to manually add an email address otherwise the handler will do # nothing. - now = self.hs.clock.time_msec() + now = self.hs.get_clock().time_msec() self.get_success( self.store.user_add_threepid( user_id=user_id, @@ -635,12 +601,11 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.email_attempts = [] # Test that we're still able to manually trigger a mail to be sent. - request, channel = self.make_request( + channel = self.make_request( b"POST", "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, ) - self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEqual(len(self.email_attempts), 1) @@ -676,7 +641,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): self.hs.config.account_validity.startup_job_max_delta = self.max_delta - now_ms = self.hs.clock.time_msec() + now_ms = self.hs.get_clock().time_msec() self.get_success(self.store._set_expiration_date_when_missing()) res = self.get_success(self.store.get_expiration_ts_for_user(user_id)) diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index 99c9f4e928..bd574077e7 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -60,12 +60,11 @@ class RelationsTestCase(unittest.HomeserverTestCase): event_id = channel.json_body["event_id"] - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/event/%s" % (self.room, event_id), access_token=self.user_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) self.assert_dict( @@ -108,13 +107,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) annotation_id = channel.json_body["event_id"] - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1" % (self.room, self.parent_id), access_token=self.user_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) # We expect to get back a single pagination result, which is the full @@ -154,13 +152,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): if prev_token: from_token = "&from=" + prev_token - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s" % (self.room, self.parent_id, from_token), access_token=self.user_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) @@ -213,13 +210,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): if prev_token: from_token = "&from=" + prev_token - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s" % (self.room, self.parent_id, from_token), access_token=self.user_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) @@ -283,7 +279,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): if prev_token: from_token = "&from=" + prev_token - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s" "/aggregations/%s/%s/m.reaction/%s?limit=1%s" @@ -296,7 +292,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): ), access_token=self.user_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) @@ -330,13 +325,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") self.assertEquals(200, channel.code, channel.json_body) - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s/aggregations/%s" % (self.room, self.parent_id), access_token=self.user_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) self.assertEquals( @@ -363,22 +357,20 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) # Now lets redact one of the 'a' reactions - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id), access_token=self.user_token, content={}, ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s/aggregations/%s" % (self.room, self.parent_id), 
access_token=self.user_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) self.assertEquals( @@ -390,13 +382,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): """Test that aggregations must be annotations. """ - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1" % (self.room, self.parent_id, RelationTypes.REPLACE), access_token=self.user_token, ) - self.render(request) self.assertEquals(400, channel.code, channel.json_body) def test_aggregation_get_event(self): @@ -423,12 +414,11 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) reply_2 = channel.json_body["event_id"] - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id), access_token=self.user_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) self.assertEquals( @@ -460,12 +450,11 @@ class RelationsTestCase(unittest.HomeserverTestCase): edit_event_id = channel.json_body["event_id"] - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id), access_token=self.user_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(channel.json_body["content"], new_body) @@ -518,12 +507,11 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) self.assertEquals(200, channel.code, channel.json_body) - request, channel = self.make_request( + channel = self.make_request( "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id), access_token=self.user_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(channel.json_body["content"], new_body) @@ -561,37 +549,34 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) # Check the relation is returned - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message" % (self.room, original_event_id), access_token=self.user_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) self.assertIn("chunk", channel.json_body) self.assertEquals(len(channel.json_body["chunk"]), 1) # Redact the original event - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/redact/%s/%s" % (self.room, original_event_id, "test_relations_redaction_redacts_edits"), access_token=self.user_token, content="{}", ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) # Try to check for remaining m.replace relations - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message" % (self.room, original_event_id), access_token=self.user_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) # Check that no relations are returned @@ -613,7 +598,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) # Redact the original - request, channel = self.make_request( + channel = self.make_request( "PUT", "/rooms/%s/redact/%s/%s" % ( @@ -624,17 +609,15 @@ class RelationsTestCase(unittest.HomeserverTestCase): access_token=self.user_token, content="{}", ) - 
self.render(request) self.assertEquals(200, channel.code, channel.json_body) # Check that aggregations returns zero - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/unstable/rooms/%s/aggregations/%s/m.annotation/m.reaction" % (self.room, original_event_id), access_token=self.user_token, ) - self.render(request) self.assertEquals(200, channel.code, channel.json_body) self.assertIn("chunk", channel.json_body) @@ -673,14 +656,13 @@ class RelationsTestCase(unittest.HomeserverTestCase): original_id = parent_id if parent_id else self.parent_id - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s" % (self.room, original_id, relation_type, event_type, query), json.dumps(content).encode("utf-8"), access_token=access_token, ) - self.render(request) return channel def _create_user(self, localpart): diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py
index 5ae72fd008..116ace1812 100644
--- a/tests/rest/client/v2_alpha/test_shared_rooms.py
+++ b/tests/rest/client/v2_alpha/test_shared_rooms.py
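Because make_request now drives the request to completion, small helpers such as _get_shared_rooms can simply return the FakeChannel, which is what the hunks below do. A condensed sketch with the return type spelled out:

from tests.server import FakeChannel

def _get_shared_rooms(self, token, other_user) -> FakeChannel:
    # Returns the completed channel directly; callers read .code / .json_body.
    return self.make_request(
        "GET",
        "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s"
        % other_user,
        access_token=token,
    )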
@@ -17,6 +17,7 @@ from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import shared_rooms from tests import unittest +from tests.server import FakeChannel class UserSharedRoomsTest(unittest.HomeserverTestCase): @@ -40,15 +41,13 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): self.store = hs.get_datastore() self.handler = hs.get_user_directory_handler() - def _get_shared_rooms(self, token, other_user): - request, channel = self.make_request( + def _get_shared_rooms(self, token, other_user) -> FakeChannel: + return self.make_request( "GET", "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s" % other_user, access_token=token, ) - self.render(request) - return request, channel def test_shared_room_list_public(self): """ @@ -64,7 +63,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): self.helper.invite(room, src=u1, targ=u2, tok=u1_token) self.helper.join(room, user=u2, tok=u2_token) - request, channel = self._get_shared_rooms(u1_token, u2) + channel = self._get_shared_rooms(u1_token, u2) self.assertEquals(200, channel.code, channel.result) self.assertEquals(len(channel.json_body["joined"]), 1) self.assertEquals(channel.json_body["joined"][0], room) @@ -83,7 +82,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): self.helper.invite(room, src=u1, targ=u2, tok=u1_token) self.helper.join(room, user=u2, tok=u2_token) - request, channel = self._get_shared_rooms(u1_token, u2) + channel = self._get_shared_rooms(u1_token, u2) self.assertEquals(200, channel.code, channel.result) self.assertEquals(len(channel.json_body["joined"]), 1) self.assertEquals(channel.json_body["joined"][0], room) @@ -105,7 +104,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): self.helper.join(room_public, user=u2, tok=u2_token) self.helper.join(room_private, user=u1, tok=u1_token) - request, channel = self._get_shared_rooms(u1_token, u2) + channel = self._get_shared_rooms(u1_token, u2) self.assertEquals(200, channel.code, channel.result) self.assertEquals(len(channel.json_body["joined"]), 2) self.assertTrue(room_public in channel.json_body["joined"]) @@ -126,13 +125,13 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase): self.helper.join(room, user=u2, tok=u2_token) # Assert user directory is not empty - request, channel = self._get_shared_rooms(u1_token, u2) + channel = self._get_shared_rooms(u1_token, u2) self.assertEquals(200, channel.code, channel.result) self.assertEquals(len(channel.json_body["joined"]), 1) self.assertEquals(channel.json_body["joined"][0], room) self.helper.leave(room, user=u1, tok=u1_token) - request, channel = self._get_shared_rooms(u2_token, u1) + channel = self._get_shared_rooms(u2_token, u1) self.assertEquals(200, channel.code, channel.result) self.assertEquals(len(channel.json_body["joined"]), 0) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index a31e44c97e..512e36c236 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
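With the new helper there is no render() call left to time out, so the "sync never completes" case in the hunks below is asserted by wrapping make_request in assertRaises. A sketch of that pattern; the TimedOutException import location is an assumption based on the tests.server helpers used elsewhere in this diff, and sync_url, access_token and next_batch come from the surrounding test:

from tests.server import TimedOutException  # assumed location of the exception

with self.assertRaises(TimedOutException):
    self.make_request("GET", sync_url % (access_token, next_batch))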
@@ -35,8 +35,7 @@ class FilterTestCase(unittest.HomeserverTestCase): ] def test_sync_argless(self): - request, channel = self.make_request("GET", "/sync") - self.render(request) + channel = self.make_request("GET", "/sync") self.assertEqual(channel.code, 200) self.assertTrue( @@ -56,8 +55,7 @@ class FilterTestCase(unittest.HomeserverTestCase): """ self.hs.config.use_presence = False - request, channel = self.make_request("GET", "/sync") - self.render(request) + channel = self.make_request("GET", "/sync") self.assertEqual(channel.code, 200) self.assertTrue( @@ -196,10 +194,9 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): tok=tok, ) - request, channel = self.make_request( + channel = self.make_request( "GET", "/sync?filter=%s" % sync_filter, access_token=tok ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) return channel.json_body["rooms"]["join"][room_id]["timeline"]["events"] @@ -248,44 +245,35 @@ class SyncTypingTests(unittest.HomeserverTestCase): self.helper.send(room, body="There!", tok=other_access_token) # Start typing. - request, channel = self.make_request( + channel = self.make_request( "PUT", typing_url % (room, other_user_id, other_access_token), b'{"typing": true, "timeout": 30000}', ) - self.render(request) self.assertEquals(200, channel.code) - request, channel = self.make_request( - "GET", "/sync?access_token=%s" % (access_token,) - ) - self.render(request) + channel = self.make_request("GET", "/sync?access_token=%s" % (access_token,)) self.assertEquals(200, channel.code) next_batch = channel.json_body["next_batch"] # Stop typing. - request, channel = self.make_request( + channel = self.make_request( "PUT", typing_url % (room, other_user_id, other_access_token), b'{"typing": false}', ) - self.render(request) self.assertEquals(200, channel.code) # Start typing. - request, channel = self.make_request( + channel = self.make_request( "PUT", typing_url % (room, other_user_id, other_access_token), b'{"typing": true, "timeout": 30000}', ) - self.render(request) self.assertEquals(200, channel.code) # Should return immediately - request, channel = self.make_request( - "GET", sync_url % (access_token, next_batch) - ) - self.render(request) + channel = self.make_request("GET", sync_url % (access_token, next_batch)) self.assertEquals(200, channel.code) next_batch = channel.json_body["next_batch"] @@ -297,10 +285,7 @@ class SyncTypingTests(unittest.HomeserverTestCase): # invalidate the stream token. self.helper.send(room, body="There!", tok=other_access_token) - request, channel = self.make_request( - "GET", sync_url % (access_token, next_batch) - ) - self.render(request) + channel = self.make_request("GET", sync_url % (access_token, next_batch)) self.assertEquals(200, channel.code) next_batch = channel.json_body["next_batch"] @@ -308,10 +293,7 @@ class SyncTypingTests(unittest.HomeserverTestCase): # ahead, and therefore it's saying the typing (that we've actually # already seen) is new, since it's got a token above our new, now-reset # stream token. - request, channel = self.make_request( - "GET", sync_url % (access_token, next_batch) - ) - self.render(request) + channel = self.make_request("GET", sync_url % (access_token, next_batch)) self.assertEquals(200, channel.code) next_batch = channel.json_body["next_batch"] @@ -320,10 +302,8 @@ class SyncTypingTests(unittest.HomeserverTestCase): typing._reset() # Now it SHOULD fail as it never completes! 
- request, channel = self.make_request( - "GET", sync_url % (access_token, next_batch) - ) - self.assertRaises(TimedOutException, self.render, request) + with self.assertRaises(TimedOutException): + self.make_request("GET", sync_url % (access_token, next_batch)) class UnreadMessagesTestCase(unittest.HomeserverTestCase): @@ -395,13 +375,12 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): # Send a read receipt to tell the server we've read the latest event. body = json.dumps({"m.read": res["event_id"]}).encode("utf8") - request, channel = self.make_request( + channel = self.make_request( "POST", "/rooms/%s/read_markers" % self.room_id, body, access_token=self.tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.json_body) # Check that the unread counter is back to 0. @@ -463,10 +442,9 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): def _check_unread_count(self, expected_count: True): """Syncs and compares the unread count with the expected value.""" - request, channel = self.make_request( + channel = self.make_request( "GET", self.url % self.next_batch, access_token=self.tok, ) - self.render(request) self.assertEqual(channel.code, 200, channel.json_body) diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 6850c666be..5e90d656f7 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
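The key-resource hunks below rename two hooks rather than just dropping render(): the per-test resource is now built in create_test_resource (previously create_test_json_resource), the response is awaited on the channel itself via channel.await_result() instead of the removed wait_until_result helper, and the mocked HTTP client is passed as federation_http_client. A condensed sketch of the new shape, reusing the names from these hunks:

def create_test_resource(self):
    # Serve the key API under /_matrix/key/v2 for these tests.
    return create_resource_tree(
        {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource()
    )

# ...after feeding a request onto the transport:
channel.await_result()
self.assertEqual(channel.code, 200)
resp = channel.json_body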
@@ -32,16 +32,16 @@ from synapse.util.httpresourcetree import create_resource_tree from synapse.util.stringutils import random_string from tests import unittest -from tests.server import FakeChannel, wait_until_result +from tests.server import FakeChannel from tests.utils import default_config class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.http_client = Mock() - return self.setup_test_homeserver(http_client=self.http_client) + return self.setup_test_homeserver(federation_http_client=self.http_client) - def create_test_json_resource(self): + def create_test_resource(self): return create_resource_tree( {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource() ) @@ -94,7 +94,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase): % (server_name.encode("utf-8"), key_id.encode("utf-8")), b"1.1", ) - wait_until_result(self.reactor, req) + channel.await_result() self.assertEqual(channel.code, 200) resp = channel.json_body return resp @@ -172,7 +172,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): } ] self.hs2 = self.setup_test_homeserver( - http_client=self.http_client2, config=config + federation_http_client=self.http_client2, config=config ) # wire up outbound POST /key/v2/query requests from hs2 so that they @@ -190,7 +190,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): req.requestReceived( b"POST", path.encode("utf-8"), b"1.1", ) - wait_until_result(self.reactor, req) + channel.await_result() self.assertEqual(channel.code, 200) resp = channel.json_body return resp diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 5f897d49cf..ae2b32b131 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
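The media tests below switch to the module-level make_request from tests.server, wrapping the specific resource under test in a FakeSite and passing await_result=False so the test can pump the reactor while the mocked remote fetch completes. A sketch of that pattern, using the download resource as the hunks below do:

from tests.server import FakeSite, make_request

channel = make_request(
    self.reactor,
    FakeSite(self.download_resource),
    "GET",
    self.media_id,
    shorthand=False,
    await_result=False,  # the response only arrives once the fake remote answers
)
self.pump()
# ...once the fake remote has responded:
self.assertEqual(channel.code, 200)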
@@ -36,6 +36,7 @@ from synapse.rest.media.v1.media_storage import MediaStorage from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend from tests import unittest +from tests.server import FakeSite, make_request class MediaStorageTests(unittest.HomeserverTestCase): @@ -213,7 +214,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): } config["media_storage_providers"] = [provider_config] - hs = self.setup_test_homeserver(config=config, http_client=client) + hs = self.setup_test_homeserver(config=config, federation_http_client=client) return hs @@ -227,8 +228,14 @@ class MediaRepoTests(unittest.HomeserverTestCase): def _req(self, content_disposition): - request, channel = self.make_request("GET", self.media_id, shorthand=False) - request.render(self.download_resource) + channel = make_request( + self.reactor, + FakeSite(self.download_resource), + "GET", + self.media_id, + shorthand=False, + await_result=False, + ) self.pump() # We've made one fetch, to example.com, using the media URL, and asking @@ -317,10 +324,14 @@ class MediaRepoTests(unittest.HomeserverTestCase): def _test_thumbnail(self, method, expected_body, expected_found): params = "?width=32&height=32&method=" + method - request, channel = self.make_request( - "GET", self.media_id + params, shorthand=False + channel = make_request( + self.reactor, + FakeSite(self.thumbnail_resource), + "GET", + self.media_id + params, + shorthand=False, + await_result=False, ) - request.render(self.thumbnail_resource) self.pump() headers = { @@ -348,7 +359,19 @@ class MediaRepoTests(unittest.HomeserverTestCase): channel.json_body, { "errcode": "M_NOT_FOUND", - "error": "Not found [b'example.com', b'12345?width=32&height=32&method=%s']" - % method, + "error": "Not found [b'example.com', b'12345']", }, ) + + def test_x_robots_tag_header(self): + """ + Tests that the `X-Robots-Tag` header is present, which informs web crawlers + to not index, archive, or follow links in media. + """ + channel = self._req(b"inline; filename=out" + self.test_image.extension) + + headers = channel.headers + self.assertEqual( + headers.getRawHeaders(b"X-Robots-Tag"), + [b"noindex, nofollow, noarchive, noimageindex"], + ) diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index c00a7b9114..83d728b4a4 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py
@@ -18,41 +18,15 @@ import re from mock import patch -import attr - from twisted.internet._resolver import HostResolution from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.error import DNSLookupError -from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol -from twisted.web._newclient import ResponseDone from tests import unittest from tests.server import FakeTransport -@attr.s -class FakeResponse: - version = attr.ib() - code = attr.ib() - phrase = attr.ib() - headers = attr.ib() - body = attr.ib() - absoluteURI = attr.ib() - - @property - def request(self): - @attr.s - class FakeTransport: - absoluteURI = self.absoluteURI - - return FakeTransport() - - def deliverBody(self, protocol): - protocol.dataReceived(self.body) - protocol.connectionLost(Failure(ResponseDone())) - - class URLPreviewTests(unittest.HomeserverTestCase): hijack_auth = True @@ -133,13 +107,18 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.reactor.nameResolver = Resolver() + def create_test_resource(self): + return self.hs.get_media_repository_resource() + def test_cache_returns_correct_type(self): self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] - request, channel = self.make_request( - "GET", "url_preview?url=http://matrix.org", shorthand=False + channel = self.make_request( + "GET", + "preview_url?url=http://matrix.org", + shorthand=False, + await_result=False, ) - request.render(self.preview_url) self.pump() client = self.reactor.tcpClients[0][2].buildProtocol(None) @@ -159,11 +138,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) # Check the cache returns the correct response - request, channel = self.make_request( - "GET", "url_preview?url=http://matrix.org", shorthand=False + channel = self.make_request( + "GET", "preview_url?url=http://matrix.org", shorthand=False ) - request.render(self.preview_url) - self.pump() # Check the cache response has the same content self.assertEqual(channel.code, 200) @@ -177,11 +154,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertNotIn("http://matrix.org", self.preview_url._cache) # Check the database cache returns the correct response - request, channel = self.make_request( - "GET", "url_preview?url=http://matrix.org", shorthand=False + channel = self.make_request( + "GET", "preview_url?url=http://matrix.org", shorthand=False ) - request.render(self.preview_url) - self.pump() # Check the cache response has the same content self.assertEqual(channel.code, 200) @@ -200,10 +175,12 @@ class URLPreviewTests(unittest.HomeserverTestCase): b"</head></html>" ) - request, channel = self.make_request( - "GET", "url_preview?url=http://matrix.org", shorthand=False + channel = self.make_request( + "GET", + "preview_url?url=http://matrix.org", + shorthand=False, + await_result=False, ) - request.render(self.preview_url) self.pump() client = self.reactor.tcpClients[0][2].buildProtocol(None) @@ -233,10 +210,12 @@ class URLPreviewTests(unittest.HomeserverTestCase): b"</head></html>" ) - request, channel = self.make_request( - "GET", "url_preview?url=http://matrix.org", shorthand=False + channel = self.make_request( + "GET", + "preview_url?url=http://matrix.org", + shorthand=False, + await_result=False, ) - request.render(self.preview_url) self.pump() client = self.reactor.tcpClients[0][2].buildProtocol(None) @@ -266,10 +245,12 @@ class URLPreviewTests(unittest.HomeserverTestCase): b"</head></html>" ) - request, channel = self.make_request( - "GET", 
"url_preview?url=http://matrix.org", shorthand=False + channel = self.make_request( + "GET", + "preview_url?url=http://matrix.org", + shorthand=False, + await_result=False, ) - request.render(self.preview_url) self.pump() client = self.reactor.tcpClients[0][2].buildProtocol(None) @@ -297,10 +278,12 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")] - request, channel = self.make_request( - "GET", "url_preview?url=http://example.com", shorthand=False + channel = self.make_request( + "GET", + "preview_url?url=http://example.com", + shorthand=False, + await_result=False, ) - request.render(self.preview_url) self.pump() client = self.reactor.tcpClients[0][2].buildProtocol(None) @@ -325,11 +308,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")] - request, channel = self.make_request( - "GET", "url_preview?url=http://example.com", shorthand=False + channel = self.make_request( + "GET", "preview_url?url=http://example.com", shorthand=False ) - request.render(self.preview_url) - self.pump() # No requests made. self.assertEqual(len(self.reactor.tcpClients), 0) @@ -348,11 +329,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")] - request, channel = self.make_request( - "GET", "url_preview?url=http://example.com", shorthand=False + channel = self.make_request( + "GET", "preview_url?url=http://example.com", shorthand=False ) - request.render(self.preview_url) - self.pump() self.assertEqual(channel.code, 502) self.assertEqual( @@ -367,11 +346,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ Blacklisted IP addresses, accessed directly, are not spidered. """ - request, channel = self.make_request( - "GET", "url_preview?url=http://192.168.1.1", shorthand=False + channel = self.make_request( + "GET", "preview_url?url=http://192.168.1.1", shorthand=False ) - request.render(self.preview_url) - self.pump() # No requests made. self.assertEqual(len(self.reactor.tcpClients), 0) @@ -388,11 +365,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ Blacklisted IP ranges, accessed directly, are not spidered. 
""" - request, channel = self.make_request( - "GET", "url_preview?url=http://1.1.1.2", shorthand=False + channel = self.make_request( + "GET", "preview_url?url=http://1.1.1.2", shorthand=False ) - request.render(self.preview_url) - self.pump() self.assertEqual(channel.code, 403) self.assertEqual( @@ -410,10 +385,12 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ self.lookups["example.com"] = [(IPv4Address, "1.1.1.1")] - request, channel = self.make_request( - "GET", "url_preview?url=http://example.com", shorthand=False + channel = self.make_request( + "GET", + "preview_url?url=http://example.com", + shorthand=False, + await_result=False, ) - request.render(self.preview_url) self.pump() client = self.reactor.tcpClients[0][2].buildProtocol(None) @@ -445,11 +422,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): (IPv4Address, "10.1.2.3"), ] - request, channel = self.make_request( - "GET", "url_preview?url=http://example.com", shorthand=False + channel = self.make_request( + "GET", "preview_url?url=http://example.com", shorthand=False ) - request.render(self.preview_url) - self.pump() self.assertEqual(channel.code, 502) self.assertEqual( channel.json_body, @@ -467,11 +442,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): (IPv6Address, "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff") ] - request, channel = self.make_request( - "GET", "url_preview?url=http://example.com", shorthand=False + channel = self.make_request( + "GET", "preview_url?url=http://example.com", shorthand=False ) - request.render(self.preview_url) - self.pump() # No requests made. self.assertEqual(len(self.reactor.tcpClients), 0) @@ -490,11 +463,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ self.lookups["example.com"] = [(IPv6Address, "2001:800::1")] - request, channel = self.make_request( - "GET", "url_preview?url=http://example.com", shorthand=False + channel = self.make_request( + "GET", "preview_url?url=http://example.com", shorthand=False ) - request.render(self.preview_url) - self.pump() self.assertEqual(channel.code, 502) self.assertEqual( @@ -509,11 +480,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): """ OPTIONS returns the OPTIONS. 
""" - request, channel = self.make_request( - "OPTIONS", "url_preview?url=http://example.com", shorthand=False + channel = self.make_request( + "OPTIONS", "preview_url?url=http://example.com", shorthand=False ) - request.render(self.preview_url) - self.pump() self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, {}) @@ -524,10 +493,12 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")] # Build and make a request to the server - request, channel = self.make_request( - "GET", "url_preview?url=http://example.com", shorthand=False + channel = self.make_request( + "GET", + "preview_url?url=http://example.com", + shorthand=False, + await_result=False, ) - request.render(self.preview_url) self.pump() # Extract Synapse's tcp client @@ -596,12 +567,12 @@ class URLPreviewTests(unittest.HomeserverTestCase): b"</head></html>" ) - request, channel = self.make_request( + channel = self.make_request( "GET", - "url_preview?url=http://twitter.com/matrixdotorg/status/12345", + "preview_url?url=http://twitter.com/matrixdotorg/status/12345", shorthand=False, + await_result=False, ) - request.render(self.preview_url) self.pump() client = self.reactor.tcpClients[0][2].buildProtocol(None) @@ -661,12 +632,12 @@ class URLPreviewTests(unittest.HomeserverTestCase): } end_content = json.dumps(result).encode("utf-8") - request, channel = self.make_request( + channel = self.make_request( "GET", - "url_preview?url=http://twitter.com/matrixdotorg/status/12345", + "preview_url?url=http://twitter.com/matrixdotorg/status/12345", shorthand=False, + await_result=False, ) - request.render(self.preview_url) self.pump() client = self.reactor.tcpClients[0][2].buildProtocol(None) diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py
index 2d021f6565..32acd93dc1 100644 --- a/tests/rest/test_health.py +++ b/tests/rest/test_health.py
@@ -20,15 +20,12 @@ from tests import unittest class HealthCheckTests(unittest.HomeserverTestCase): - def setUp(self): - super().setUp() - + def create_test_resource(self): # replace the JsonResource with a HealthResource. - self.resource = HealthResource() + return HealthResource() def test_health(self): - request, channel = self.make_request("GET", "/health", shorthand=False) - self.render(request) + channel = self.make_request("GET", "/health", shorthand=False) - self.assertEqual(request.code, 200) + self.assertEqual(channel.code, 200) self.assertEqual(channel.result["body"], b"OK") diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index dcd65c2a50..14de0921be 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py
@@ -20,22 +20,19 @@ from tests import unittest class WellKnownTests(unittest.HomeserverTestCase): - def setUp(self): - super().setUp() - + def create_test_resource(self): # replace the JsonResource with a WellKnownResource - self.resource = WellKnownResource(self.hs) + return WellKnownResource(self.hs) def test_well_known(self): self.hs.config.public_baseurl = "https://tesths" self.hs.config.default_identity_server = "https://testis" - request, channel = self.make_request( + channel = self.make_request( "GET", "/.well-known/matrix/client", shorthand=False ) - self.render(request) - self.assertEqual(request.code, 200) + self.assertEqual(channel.code, 200) self.assertEqual( channel.json_body, { @@ -47,9 +44,8 @@ class WellKnownTests(unittest.HomeserverTestCase): def test_well_known_no_public_baseurl(self): self.hs.config.public_baseurl = None - request, channel = self.make_request( + channel = self.make_request( "GET", "/.well-known/matrix/client", shorthand=False ) - self.render(request) - self.assertEqual(request.code, 404) + self.assertEqual(channel.code, 404) diff --git a/tests/server.py b/tests/server.py
index b404ad4e2a..7d1ad362c4 100644 --- a/tests/server.py +++ b/tests/server.py
@@ -1,8 +1,11 @@ import json import logging +from collections import deque from io import SEEK_END, BytesIO +from typing import Callable, Iterable, Optional, Tuple, Union import attr +from typing_extensions import Deque from zope.interface import implementer from twisted.internet import address, threads, udp @@ -16,8 +19,8 @@ from twisted.internet.interfaces import ( ) from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock -from twisted.web.http import unquote from twisted.web.http_headers import Headers +from twisted.web.resource import IResource from twisted.web.server import Site from synapse.http.site import SynapseRequest @@ -43,7 +46,7 @@ class FakeChannel: site = attr.ib(type=Site) _reactor = attr.ib() - result = attr.ib(default=attr.Factory(dict)) + result = attr.ib(type=dict, default=attr.Factory(dict)) _producer = None @property @@ -114,6 +117,25 @@ class FakeChannel: def transport(self): return self + def await_result(self, timeout: int = 100) -> None: + """ + Wait until the request is finished. + """ + self._reactor.run() + x = 0 + + while not self.result.get("done"): + # If there's a producer, tell it to resume producing so we get content + if self._producer: + self._producer.resumeProducing() + + x += 1 + + if x > timeout: + raise TimedOutException("Timed out waiting for request to finish.") + + self._reactor.advance(0.1) + class FakeSite: """ @@ -125,9 +147,21 @@ class FakeSite: site_tag = "test" access_logger = logging.getLogger("synapse.access.http.fake") + def __init__(self, resource: IResource): + """ + + Args: + resource: the resource to be used for rendering all requests + """ + self._resource = resource + + def getResourceFor(self, request): + return self._resource + def make_request( reactor, + site: Site, method, path, content=b"", @@ -136,12 +170,19 @@ def make_request( shorthand=True, federation_auth_origin=None, content_is_form=False, -): + await_result: bool = True, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, +) -> FakeChannel: """ - Make a web request using the given method and path, feed it the - content, and return the Request and the Channel underneath. + Make a web request using the given method, path and content, and render it + + Returns the fake Channel object which records the response to the request. Args: + site: The twisted Site to use to render the request + method (bytes/unicode): The HTTP request method ("verb"). path (bytes/unicode): The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces and such). @@ -154,8 +195,14 @@ def make_request( content_is_form: Whether the content is URL encoded form data. Adds the 'Content-Type': 'application/x-www-form-urlencoded' header. + custom_headers: (name, value) pairs to add as request headers + + await_result: whether to wait for the request to complete rendering. If true, + will pump the reactor until the the renderer tells the channel the request + is finished. 
+ Returns: - Tuple[synapse.http.site.SynapseRequest, channel] + channel """ if not isinstance(method, bytes): method = method.encode("ascii") @@ -169,24 +216,24 @@ def make_request( and not path.startswith(b"/_matrix") and not path.startswith(b"/_synapse") ): + if path.startswith(b"/"): + path = path[1:] path = b"/_matrix/client/r0/" + path - path = path.replace(b"//", b"/") if not path.startswith(b"/"): path = b"/" + path + if isinstance(content, dict): + content = json.dumps(content).encode("utf8") if isinstance(content, str): content = content.encode("utf8") - site = FakeSite() channel = FakeChannel(site, reactor) req = request(channel) - req.process = lambda: b"" req.content = BytesIO(content) # Twisted expects to be at the end of the content when parsing the request. req.content.seek(SEEK_END) - req.postpath = list(map(unquote, path[1:].split(b"/"))) if access_token: req.requestHeaders.addRawHeader( @@ -208,35 +255,17 @@ def make_request( # Assume the body is JSON req.requestHeaders.addRawHeader(b"Content-Type", b"application/json") - req.requestReceived(method, path, b"1.1") + if custom_headers: + for k, v in custom_headers: + req.requestHeaders.addRawHeader(k, v) - return req, channel - - -def wait_until_result(clock, request, timeout=100): - """ - Wait until the request is finished. - """ - clock.run() - x = 0 - - while not request.finished: - - # If there's a producer, tell it to resume producing so we get content - if request._channel._producer: - request._channel._producer.resumeProducing() - - x += 1 - - if x > timeout: - raise TimedOutException("Timed out waiting for request to finish.") - - clock.advance(0.1) + req.parseCookies() + req.requestReceived(method, path, b"1.1") + if await_result: + channel.await_result() -def render(request, resource, clock): - request.render(resource) - wait_until_result(clock, request) + return channel @implementer(IReactorPluggableNameResolver) @@ -251,6 +280,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): self._tcp_callbacks = {} self._udp = [] lookups = self.lookups = {} + self._thread_callbacks = deque() # type: Deque[Callable[[], None]]() @implementer(IResolverSimple) class FakeResolver: @@ -272,10 +302,10 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): """ Make the callback fire in the next reactor iteration. """ - d = Deferred() - d.addCallback(lambda x: callback(*args, **kwargs)) - self.callLater(0, d.callback, True) - return d + cb = lambda: callback(*args, **kwargs) + # it's not safe to call callLater() here, so we append the callback to a + # separate queue. + self._thread_callbacks.append(cb) def getThreadPool(self): return self.threadpool @@ -303,6 +333,30 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): return conn + def advance(self, amount): + # first advance our reactor's time, and run any "callLater" callbacks that + # makes ready + super().advance(amount) + + # now run any "callFromThread" callbacks + while True: + try: + callback = self._thread_callbacks.popleft() + except IndexError: + break + callback() + + # check for more "callLater" callbacks added by the thread callback + # This isn't required in a regular reactor, but it ends up meaning that + # our database queries can complete in a single call to `advance` [1] which + # simplifies tests. + # + # [1]: we replace the threadpool backing the db connection pool with a + # mock ThreadPool which doesn't really use threads; but we still use + # reactor.callFromThread to feed results back from the db functions to the + # main thread. 
+ super().advance(0) + class ThreadPool: """ @@ -339,8 +393,6 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): """ server = _sth(cleanup_func, *args, **kwargs) - database = server.config.database.get_single_database() - # Make the thread pool synchronous. clock = server.get_clock() @@ -354,7 +406,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): pool._runWithConnection, func, *args, - **kwargs + **kwargs, ) def runInteraction(interaction, *args, **kwargs): @@ -364,7 +416,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): pool._runInteraction, interaction, *args, - **kwargs + **kwargs, ) pool.runWithConnection = runWithConnection @@ -372,6 +424,10 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): pool.threadpool = ThreadPool(clock._reactor) pool.running = True + # We've just changed the Databases to run DB transactions on the same + # thread, so we need to disable the dedicated thread behaviour. + server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False + return server @@ -541,12 +597,10 @@ def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol reactor factory: The connecting factory to build. """ - factory = reactor.tcpClients[client_id][2] + factory = reactor.tcpClients.pop(client_id)[2] client = factory.buildProtocol(None) server = AccumulatingProtocol() server.makeConnection(FakeTransport(client, reactor)) client.makeConnection(FakeTransport(server, reactor)) - reactor.tcpClients.pop(client_id) - return client, server diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py
index 872039c8f1..4dd5a36178 100644 --- a/tests/server_notices/test_consent.py +++ b/tests/server_notices/test_consent.py
@@ -70,29 +70,26 @@ class ConsentNoticesTests(unittest.HomeserverTestCase): the notice URL + an authentication code. """ # Initial sync, to get the user consent room invite - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/sync", access_token=self.access_token ) - self.render(request) self.assertEqual(channel.code, 200) # Get the Room ID to join room_id = list(channel.json_body["rooms"]["invite"].keys())[0] # Join the room - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/rooms/" + room_id + "/join", access_token=self.access_token, ) - self.render(request) self.assertEqual(channel.code, 200) # Sync again, to get the message in the room - request, channel = self.make_request( + channel = self.make_request( "GET", "/_matrix/client/r0/sync", access_token=self.access_token ) - self.render(request) self.assertEqual(channel.code, 200) # Get the message diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 6382b19dc3..fea54464af 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -305,8 +305,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): self.register_user("user", "password") tok = self.login("user", "password") - request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok) - self.render(request) + channel = self.make_request("GET", "/sync?timeout=0", access_token=tok) invites = channel.json_body["rooms"]["invite"] self.assertEqual(len(invites), 0, invites) @@ -319,8 +318,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): # Sync again to retrieve the events in the room, so we can check whether this # room has a notice in it. - request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok) - self.render(request) + channel = self.make_request("GET", "/sync?timeout=0", access_token=tok) # Scan the events in the room to search for a message from the server notices # user. @@ -355,10 +353,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): tok = self.login(localpart, "password") # Sync with the user's token to mark the user as active. - request, channel = self.make_request( - "GET", "/sync?timeout=0", access_token=tok, - ) - self.render(request) + channel = self.make_request("GET", "/sync?timeout=0", access_token=tok,) # Also retrieves the list of invites for this user. We don't care about that # one except if we're processing the last user, which should have received an diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index ad9bbef9d2..77c72834f2 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py
@@ -24,7 +24,11 @@ from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.room_versions import RoomVersions from synapse.event_auth import auth_types_for_event from synapse.events import make_event_from_dict -from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store +from synapse.state.v2 import ( + _get_auth_chain_difference, + lexicographical_topological_sort, + resolve_events_with_store, +) from synapse.types import EventID from tests import unittest @@ -84,7 +88,7 @@ class FakeEvent: event_dict = { "auth_events": [(a, {}) for a in auth_events], "prev_events": [(p, {}) for p in prev_events], - "event_id": self.node_id, + "event_id": self.event_id, "sender": self.sender, "type": self.type, "content": self.content, @@ -377,6 +381,61 @@ class StateTestCase(unittest.TestCase): self.do_check(events, edges, expected_state_ids) + def test_mainline_sort(self): + """Tests that the mainline ordering works correctly. + """ + + events = [ + FakeEvent( + id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={} + ), + FakeEvent( + id="PA1", + sender=ALICE, + type=EventTypes.PowerLevels, + state_key="", + content={"users": {ALICE: 100, BOB: 50}}, + ), + FakeEvent( + id="T2", sender=ALICE, type=EventTypes.Topic, state_key="", content={} + ), + FakeEvent( + id="PA2", + sender=ALICE, + type=EventTypes.PowerLevels, + state_key="", + content={ + "users": {ALICE: 100, BOB: 50}, + "events": {EventTypes.PowerLevels: 100}, + }, + ), + FakeEvent( + id="PB", + sender=BOB, + type=EventTypes.PowerLevels, + state_key="", + content={"users": {ALICE: 100, BOB: 50}}, + ), + FakeEvent( + id="T3", sender=BOB, type=EventTypes.Topic, state_key="", content={} + ), + FakeEvent( + id="T4", sender=ALICE, type=EventTypes.Topic, state_key="", content={} + ), + ] + + edges = [ + ["END", "T3", "PA2", "T2", "PA1", "T1", "START"], + ["END", "T4", "PB", "PA1"], + ] + + # We expect T3 to be picked as the other topics are pointing at older + # power levels. Note that without mainline ordering we'd pick T4 due to + # it being sent *after* T3. + expected_state_ids = ["T3", "PA2"] + + self.do_check(events, edges, expected_state_ids) + def do_check(self, events, edges, expected_state_ids): """Take a list of events and edges and calculate the state of the graph at END, and asserts it matches `expected_state_ids` @@ -587,6 +646,134 @@ class SimpleParamStateTestCase(unittest.TestCase): self.assert_dict(self.expected_combined_state, state) +class AuthChainDifferenceTestCase(unittest.TestCase): + """We test that `_get_auth_chain_difference` correctly handles unpersisted + events. 
+ """ + + def test_simple(self): + # Test getting the auth difference for a simple chain with a single + # unpersisted event: + # + # Unpersisted | Persisted + # | + # C -|-> B -> A + + a = FakeEvent( + id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([], []) + + b = FakeEvent( + id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([a.event_id], []) + + c = FakeEvent( + id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([b.event_id], []) + + persisted_events = {a.event_id: a, b.event_id: b} + unpersited_events = {c.event_id: c} + + state_sets = [{"a": a.event_id, "b": b.event_id}, {"c": c.event_id}] + + store = TestStateResolutionStore(persisted_events) + + diff_d = _get_auth_chain_difference( + ROOM_ID, state_sets, unpersited_events, store + ) + difference = self.successResultOf(defer.ensureDeferred(diff_d)) + + self.assertEqual(difference, {c.event_id}) + + def test_multiple_unpersisted_chain(self): + # Test getting the auth difference for a simple chain with multiple + # unpersisted events: + # + # Unpersisted | Persisted + # | + # D -> C -|-> B -> A + + a = FakeEvent( + id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([], []) + + b = FakeEvent( + id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([a.event_id], []) + + c = FakeEvent( + id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([b.event_id], []) + + d = FakeEvent( + id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([c.event_id], []) + + persisted_events = {a.event_id: a, b.event_id: b} + unpersited_events = {c.event_id: c, d.event_id: d} + + state_sets = [ + {"a": a.event_id, "b": b.event_id}, + {"c": c.event_id, "d": d.event_id}, + ] + + store = TestStateResolutionStore(persisted_events) + + diff_d = _get_auth_chain_difference( + ROOM_ID, state_sets, unpersited_events, store + ) + difference = self.successResultOf(defer.ensureDeferred(diff_d)) + + self.assertEqual(difference, {d.event_id, c.event_id}) + + def test_unpersisted_events_different_sets(self): + # Test getting the auth difference for with multiple unpersisted events + # in different branches: + # + # Unpersisted | Persisted + # | + # D --> C -|-> B -> A + # E ----^ -|---^ + # | + + a = FakeEvent( + id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([], []) + + b = FakeEvent( + id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([a.event_id], []) + + c = FakeEvent( + id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([b.event_id], []) + + d = FakeEvent( + id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([c.event_id], []) + + e = FakeEvent( + id="E", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([c.event_id, b.event_id], []) + + persisted_events = {a.event_id: a, b.event_id: b} + unpersited_events = {c.event_id: c, d.event_id: d, e.event_id: e} + + state_sets = [ + {"a": a.event_id, "b": b.event_id, "e": e.event_id}, + {"c": c.event_id, "d": d.event_id}, + ] + + store = TestStateResolutionStore(persisted_events) + + diff_d = _get_auth_chain_difference( + ROOM_ID, state_sets, unpersited_events, store + ) + difference = self.successResultOf(defer.ensureDeferred(diff_d)) + + self.assertEqual(difference, {d.event_id, e.event_id}) + + def pairwise(iterable): "s -> (s0,s1), 
(s1,s2), (s2, s3), ..." a, b = itertools.tee(iterable) @@ -647,7 +834,7 @@ class TestStateResolutionStore: return list(result) - def get_auth_chain_difference(self, auth_sets): + def get_auth_chain_difference(self, room_id, auth_sets): chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets] common = set(chains[0]).intersection(*chains[1:]) diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index f5afed017c..1ac4ebc61d 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py
@@ -15,308 +15,9 @@ # limitations under the License. -from mock import Mock - -from twisted.internet import defer - -from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches.descriptors import Cache, cached - from tests import unittest -class CacheTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, homeserver): - self.cache = Cache("test") - - def test_empty(self): - failed = False - try: - self.cache.get("foo") - except KeyError: - failed = True - - self.assertTrue(failed) - - def test_hit(self): - self.cache.prefill("foo", 123) - - self.assertEquals(self.cache.get("foo"), 123) - - def test_invalidate(self): - self.cache.prefill(("foo",), 123) - self.cache.invalidate(("foo",)) - - failed = False - try: - self.cache.get(("foo",)) - except KeyError: - failed = True - - self.assertTrue(failed) - - def test_eviction(self): - cache = Cache("test", max_entries=2) - - cache.prefill(1, "one") - cache.prefill(2, "two") - cache.prefill(3, "three") # 1 will be evicted - - failed = False - try: - cache.get(1) - except KeyError: - failed = True - - self.assertTrue(failed) - - cache.get(2) - cache.get(3) - - def test_eviction_lru(self): - cache = Cache("test", max_entries=2) - - cache.prefill(1, "one") - cache.prefill(2, "two") - - # Now access 1 again, thus causing 2 to be least-recently used - cache.get(1) - - cache.prefill(3, "three") - - failed = False - try: - cache.get(2) - except KeyError: - failed = True - - self.assertTrue(failed) - - cache.get(1) - cache.get(3) - - -class CacheDecoratorTestCase(unittest.HomeserverTestCase): - @defer.inlineCallbacks - def test_passthrough(self): - class A: - @cached() - def func(self, key): - return key - - a = A() - - self.assertEquals((yield a.func("foo")), "foo") - self.assertEquals((yield a.func("bar")), "bar") - - @defer.inlineCallbacks - def test_hit(self): - callcount = [0] - - class A: - @cached() - def func(self, key): - callcount[0] += 1 - return key - - a = A() - yield a.func("foo") - - self.assertEquals(callcount[0], 1) - - self.assertEquals((yield a.func("foo")), "foo") - self.assertEquals(callcount[0], 1) - - @defer.inlineCallbacks - def test_invalidate(self): - callcount = [0] - - class A: - @cached() - def func(self, key): - callcount[0] += 1 - return key - - a = A() - yield a.func("foo") - - self.assertEquals(callcount[0], 1) - - a.func.invalidate(("foo",)) - - yield a.func("foo") - - self.assertEquals(callcount[0], 2) - - def test_invalidate_missing(self): - class A: - @cached() - def func(self, key): - return key - - A().func.invalidate(("what",)) - - @defer.inlineCallbacks - def test_max_entries(self): - callcount = [0] - - class A: - @cached(max_entries=10) - def func(self, key): - callcount[0] += 1 - return key - - a = A() - - for k in range(0, 12): - yield a.func(k) - - self.assertEquals(callcount[0], 12) - - # There must have been at least 2 evictions, meaning if we calculate - # all 12 values again, we must get called at least 2 more times - for k in range(0, 12): - yield a.func(k) - - self.assertTrue( - callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0]) - ) - - def test_prefill(self): - callcount = [0] - - d = defer.succeed(123) - - class A: - @cached() - def func(self, key): - callcount[0] += 1 - return d - - a = A() - - a.func.prefill(("foo",), ObservableDeferred(d)) - - self.assertEquals(a.func("foo").result, d.result) - self.assertEquals(callcount[0], 0) - - @defer.inlineCallbacks - def test_invalidate_context(self): - callcount = [0] - callcount2 = [0] - - 
class A: - @cached() - def func(self, key): - callcount[0] += 1 - return key - - @cached(cache_context=True) - def func2(self, key, cache_context): - callcount2[0] += 1 - return self.func(key, on_invalidate=cache_context.invalidate) - - a = A() - yield a.func2("foo") - - self.assertEquals(callcount[0], 1) - self.assertEquals(callcount2[0], 1) - - a.func.invalidate(("foo",)) - yield a.func("foo") - - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 1) - - yield a.func2("foo") - - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 2) - - @defer.inlineCallbacks - def test_eviction_context(self): - callcount = [0] - callcount2 = [0] - - class A: - @cached(max_entries=2) - def func(self, key): - callcount[0] += 1 - return key - - @cached(cache_context=True) - def func2(self, key, cache_context): - callcount2[0] += 1 - return self.func(key, on_invalidate=cache_context.invalidate) - - a = A() - yield a.func2("foo") - yield a.func2("foo2") - - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 2) - - yield a.func2("foo") - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 2) - - yield a.func("foo3") - - self.assertEquals(callcount[0], 3) - self.assertEquals(callcount2[0], 2) - - yield a.func2("foo") - - self.assertEquals(callcount[0], 4) - self.assertEquals(callcount2[0], 3) - - @defer.inlineCallbacks - def test_double_get(self): - callcount = [0] - callcount2 = [0] - - class A: - @cached() - def func(self, key): - callcount[0] += 1 - return key - - @cached(cache_context=True) - def func2(self, key, cache_context): - callcount2[0] += 1 - return self.func(key, on_invalidate=cache_context.invalidate) - - a = A() - a.func2.cache.cache = Mock(wraps=a.func2.cache.cache) - - yield a.func2("foo") - - self.assertEquals(callcount[0], 1) - self.assertEquals(callcount2[0], 1) - - a.func2.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.pop.call_count, 1) - - yield a.func2("foo") - a.func2.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.pop.call_count, 2) - - self.assertEquals(callcount[0], 1) - self.assertEquals(callcount2[0], 2) - - a.func.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.pop.call_count, 3) - yield a.func("foo") - - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 2) - - yield a.func2("foo") - - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 3) - - class UpsertManyTests(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.storage = hs.get_datastore() diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 46f94914ff..1ce29af5fd 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py
@@ -58,7 +58,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): # must be done after inserts database = hs.get_datastores().databases[0] self.store = ApplicationServiceStore( - database, make_conn(database._database_config, database.engine), hs + database, make_conn(database._database_config, database.engine, "test"), hs ) def tearDown(self): @@ -132,7 +132,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): db_config = hs.config.get_single_database() self.store = TestTransactionStore( - database, make_conn(db_config, self.engine), hs + database, make_conn(db_config, self.engine, "test"), hs ) def _add_service(self, url, as_token, id): @@ -244,7 +244,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): service = Mock(id=self.as_list[0]["id"]) events = [Mock(event_id="e1"), Mock(event_id="e2")] txn = yield defer.ensureDeferred( - self.store.create_appservice_txn(service, events) + self.store.create_appservice_txn(service, events, []) ) self.assertEquals(txn.id, 1) self.assertEquals(txn.events, events) @@ -258,7 +258,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self._insert_txn(service.id, 9644, events) yield self._insert_txn(service.id, 9645, events) txn = yield defer.ensureDeferred( - self.store.create_appservice_txn(service, events) + self.store.create_appservice_txn(service, events, []) ) self.assertEquals(txn.id, 9646) self.assertEquals(txn.events, events) @@ -270,7 +270,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): events = [Mock(event_id="e1"), Mock(event_id="e2")] yield self._set_last_txn(service.id, 9643) txn = yield defer.ensureDeferred( - self.store.create_appservice_txn(service, events) + self.store.create_appservice_txn(service, events, []) ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) @@ -293,7 +293,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self._insert_txn(self.as_list[3]["id"], 9643, events) txn = yield defer.ensureDeferred( - self.store.create_appservice_txn(service, events) + self.store.create_appservice_txn(service, events, []) ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) @@ -410,6 +410,62 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): ) +class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + return hs + + def prepare(self, hs, reactor, clock): + self.service = Mock(id="foo") + self.store = self.hs.get_datastore() + self.get_success(self.store.set_appservice_state(self.service, "up")) + + def test_get_type_stream_id_for_appservice_no_value(self): + value = self.get_success( + self.store.get_type_stream_id_for_appservice(self.service, "read_receipt") + ) + self.assertEquals(value, 0) + + value = self.get_success( + self.store.get_type_stream_id_for_appservice(self.service, "presence") + ) + self.assertEquals(value, 0) + + def test_get_type_stream_id_for_appservice_invalid_type(self): + self.get_failure( + self.store.get_type_stream_id_for_appservice(self.service, "foobar"), + ValueError, + ) + + def test_set_type_stream_id_for_appservice(self): + read_receipt_value = 1024 + self.get_success( + self.store.set_type_stream_id_for_appservice( + self.service, "read_receipt", read_receipt_value + ) + ) + result = self.get_success( + self.store.get_type_stream_id_for_appservice(self.service, "read_receipt") + ) + self.assertEqual(result, 
read_receipt_value) + + self.get_success( + self.store.set_type_stream_id_for_appservice( + self.service, "presence", read_receipt_value + ) + ) + result = self.get_success( + self.store.get_type_stream_id_for_appservice(self.service, "presence") + ) + self.assertEqual(result, read_receipt_value) + + def test_set_type_stream_id_for_appservice_invalid_type(self): + self.get_failure( + self.store.set_type_stream_id_for_appservice(self.service, "foobar", 1024), + ValueError, + ) + + # required for ApplicationServiceTransactionStoreTestCase tests class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore): def __init__(self, database: DatabasePool, db_conn, hs): @@ -448,7 +504,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): database = hs.get_datastores().databases[0] ApplicationServiceStore( - database, make_conn(database._database_config, database.engine), hs + database, make_conn(database._database_config, database.engine, "test"), hs ) @defer.inlineCallbacks @@ -467,7 +523,9 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): with self.assertRaises(ConfigError) as cm: database = hs.get_datastores().databases[0] ApplicationServiceStore( - database, make_conn(database._database_config, database.engine), hs + database, + make_conn(database._database_config, database.engine, "test"), + hs, ) e = cm.exception @@ -491,7 +549,9 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): with self.assertRaises(ConfigError) as cm: database = hs.get_datastores().databases[0] ApplicationServiceStore( - database, make_conn(database._database_config, database.engine), hs + database, + make_conn(database._database_config, database.engine, "test"), + hs, ) e = cm.exception diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 080761d1d2..c13a57dad1 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py
@@ -22,7 +22,7 @@ import synapse.rest.admin from synapse.api.constants import EventTypes from synapse.rest.client.v1 import login, room from synapse.storage import prepare_database -from synapse.types import Requester, UserID +from synapse.types import UserID, create_requester from tests.unittest import HomeserverTestCase @@ -38,7 +38,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID("alice", "test") - self.requester = Requester(self.user, None, False, False, None, None) + self.requester = create_requester(self.user) info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] @@ -260,7 +260,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID.from_string(self.register_user("user1", "password")) self.token1 = self.login("user1", "password") - self.requester = Requester(self.user, None, False, False, None, None) + self.requester = create_requester(self.user) info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] self.event_creator = homeserver.get_event_creation_handler() @@ -309,36 +309,6 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): ) self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids)) - @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0) - def test_send_dummy_event_without_consent(self): - self._create_extremity_rich_graph() - self._enable_consent_checking() - - # Pump the reactor repeatedly so that the background updates have a - # chance to run. Attempt to add dummy event with user that has not consented - # Check that dummy event send fails. - self.pump(10 * 60) - latest_event_ids = self.get_success( - self.store.get_latest_event_ids_in_room(self.room_id) - ) - self.assertTrue(len(latest_event_ids) == self.EXTREMITIES_COUNT) - - # Create new user, and add consent - user2 = self.register_user("user2", "password") - token2 = self.login("user2", "password") - self.get_success( - self.store.user_set_consent_version(user2, self.CONSENT_VERSION) - ) - self.helper.join(self.room_id, user2, tok=token2) - - # Background updates should now cause a dummy event to be added to the graph - self.pump(10 * 60) - - latest_event_ids = self.get_success( - self.store.get_latest_event_ids_in_room(self.room_id) - ) - self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids)) - @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=250) def test_expiry_logic(self): """Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion() diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 755c70db31..a69117c5a9 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py
@@ -21,6 +21,7 @@ from synapse.http.site import XForwardedForRequest from synapse.rest.client.v1 import login from tests import unittest +from tests.server import make_request from tests.test_utils import make_awaitable from tests.unittest import override_config @@ -408,18 +409,18 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase): # Advance to a known time self.reactor.advance(123456 - self.reactor.seconds()) - request, channel = self.make_request( + headers1 = {b"User-Agent": b"Mozzila pizza"} + headers1.update(headers) + + make_request( + self.reactor, + self.site, "GET", - "/_matrix/client/r0/admin/users/" + self.user_id, + "/_synapse/admin/v1/users/" + self.user_id, access_token=access_token, - **make_request_args + custom_headers=headers1.items(), + **make_request_args, ) - request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza") - - # Add the optional headers - for h, v in headers.items(): - request.requestHeaders.addRawHeader(h, v) - self.render(request) # Advance so the save loop occurs self.reactor.advance(100) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index ecb00f4e02..dabc1c5f09 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py
@@ -80,6 +80,32 @@ class DeviceStoreTestCase(tests.unittest.TestCase): ) @defer.inlineCallbacks + def test_count_devices_by_users(self): + yield defer.ensureDeferred( + self.store.store_device("user_id", "device1", "display_name 1") + ) + yield defer.ensureDeferred( + self.store.store_device("user_id", "device2", "display_name 2") + ) + yield defer.ensureDeferred( + self.store.store_device("user_id2", "device3", "display_name 3") + ) + + res = yield defer.ensureDeferred(self.store.count_devices_by_users()) + self.assertEqual(0, res) + + res = yield defer.ensureDeferred(self.store.count_devices_by_users(["unknown"])) + self.assertEqual(0, res) + + res = yield defer.ensureDeferred(self.store.count_devices_by_users(["user_id"])) + self.assertEqual(2, res) + + res = yield defer.ensureDeferred( + self.store.count_devices_by_users(["user_id", "user_id2"]) + ) + self.assertEqual(3, res) + + @defer.inlineCallbacks def test_get_device_updates_by_remote(self): device_ids = ["device_id1", "device_id2"] diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py
index 35dafbb904..3d7760d5d9 100644 --- a/tests/storage/test_e2e_room_keys.py +++ b/tests/storage/test_e2e_room_keys.py
@@ -26,7 +26,7 @@ room_key = { class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver("server", http_client=None) + hs = self.setup_test_homeserver("server", federation_http_client=None) self.store = hs.get_datastore() return hs diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index d4c3b867e3..482506d731 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py
@@ -202,34 +202,41 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): # Now actually test that various combinations give the right result: difference = self.get_success( - self.store.get_auth_chain_difference([{"a"}, {"b"}]) + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}]) ) self.assertSetEqual(difference, {"a", "b"}) difference = self.get_success( - self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}]) + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}]) ) self.assertSetEqual(difference, {"a", "b", "c", "e", "f"}) difference = self.get_success( - self.store.get_auth_chain_difference([{"a", "c"}, {"b"}]) + self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}]) ) self.assertSetEqual(difference, {"a", "b", "c"}) difference = self.get_success( - self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}]) + self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}]) + ) + self.assertSetEqual(difference, {"a", "b"}) + + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}]) ) self.assertSetEqual(difference, {"a", "b", "d", "e"}) difference = self.get_success( - self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}]) + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}]) ) self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"}) difference = self.get_success( - self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}]) + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}]) ) self.assertSetEqual(difference, {"a", "b"}) - difference = self.get_success(self.store.get_auth_chain_difference([{"a"}])) + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}]) + ) self.assertSetEqual(difference, set()) diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index 3957471f3f..7691f2d790 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py
@@ -14,7 +14,7 @@ # limitations under the License. from synapse.metrics import REGISTRY, generate_latest -from synapse.types import Requester, UserID +from synapse.types import UserID, create_requester from tests.unittest import HomeserverTestCase @@ -27,7 +27,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase): room_creator = self.hs.get_room_creation_handler() user = UserID("alice", "test") - requester = Requester(user, None, False, False, None, None) + requester = create_requester(user) # Real events, forward extremities events = [(3, 2), (6, 2), (4, 6)] diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py new file mode 100644
index 0000000000..71210ce606 --- /dev/null +++ b/tests/storage/test_events.py
@@ -0,0 +1,334 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from synapse.api.constants import EventTypes, Membership +from synapse.api.room_versions import RoomVersions +from synapse.federation.federation_base import event_from_pdu_json +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests.unittest import HomeserverTestCase + + +class ExtremPruneTestCase(HomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.state = self.hs.get_state_handler() + self.persistence = self.hs.get_storage().persistence + self.store = self.hs.get_datastore() + + self.register_user("user", "pass") + self.token = self.login("user", "pass") + + self.room_id = self.helper.create_room_as( + "user", room_version=RoomVersions.V6.identifier, tok=self.token + ) + + body = self.helper.send(self.room_id, body="Test", tok=self.token) + local_message_event_id = body["event_id"] + + # Fudge a remote event and persist it. This will be the extremity before + # the gap. + self.remote_event_1 = event_from_pdu_json( + { + "type": EventTypes.Message, + "state_key": "@user:other", + "content": {}, + "room_id": self.room_id, + "sender": "@user:other", + "depth": 5, + "prev_events": [local_message_event_id], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + self.persist_event(self.remote_event_1) + + # Check that the current extremities is the remote event. + self.assert_extremities([self.remote_event_1.event_id]) + + def persist_event(self, event, state=None): + """Persist the event, with optional state + """ + context = self.get_success( + self.state.compute_event_context(event, old_state=state) + ) + self.get_success(self.persistence.persist_event(event, context)) + + def assert_extremities(self, expected_extremities): + """Assert the current extremities for the room + """ + extremities = self.get_success( + self.store.get_prev_events_for_room(self.room_id) + ) + self.assertCountEqual(extremities, expected_extremities) + + def test_prune_gap(self): + """Test that we drop extremities after a gap when we see an event from + the same domain. + """ + + # Fudge a second event which points to an event we don't have. This is a + # state event so that the state changes (otherwise we won't prune the + # extremity as they'll have the same state group). 
+ remote_event_2 = event_from_pdu_json( + { + "type": EventTypes.Member, + "state_key": "@user:other", + "content": {"membership": Membership.JOIN}, + "room_id": self.room_id, + "sender": "@user:other", + "depth": 50, + "prev_events": ["$some_unknown_message"], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + + self.persist_event(remote_event_2, state=state_before_gap.values()) + + # Check the new extremity is just the new remote event. + self.assert_extremities([remote_event_2.event_id]) + + def test_do_not_prune_gap_if_state_different(self): + """Test that we don't prune extremities after a gap if the resolved + state is different. + """ + + # Fudge a second event which points to an event we don't have. + remote_event_2 = event_from_pdu_json( + { + "type": EventTypes.Message, + "state_key": "@user:other", + "content": {}, + "room_id": self.room_id, + "sender": "@user:other", + "depth": 10, + "prev_events": ["$some_unknown_message"], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + # Now we persist it with state with a dropped history visibility + # setting. The state resolution across the old and new event will then + # include it, and so the resolved state won't match the new state. + state_before_gap = dict( + self.get_success(self.state.get_current_state(self.room_id)) + ) + state_before_gap.pop(("m.room.history_visibility", "")) + + context = self.get_success( + self.state.compute_event_context( + remote_event_2, old_state=state_before_gap.values() + ) + ) + + self.get_success(self.persistence.persist_event(remote_event_2, context)) + + # Check that we haven't dropped the old extremity. + self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id]) + + def test_prune_gap_if_old(self): + """Test that we drop extremities after a gap when the previous extremity + is "old" + """ + + # Advance the clock for many days to make the old extremity "old". We + # also set the depth to "lots". + self.reactor.advance(7 * 24 * 60 * 60) + + # Fudge a second event which points to an event we don't have. This is a + # state event so that the state changes (otherwise we won't prune the + # extremity as they'll have the same state group). + remote_event_2 = event_from_pdu_json( + { + "type": EventTypes.Member, + "state_key": "@user:other2", + "content": {"membership": Membership.JOIN}, + "room_id": self.room_id, + "sender": "@user:other2", + "depth": 10000, + "prev_events": ["$some_unknown_message"], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + + self.persist_event(remote_event_2, state=state_before_gap.values()) + + # Check the new extremity is just the new remote event. + self.assert_extremities([remote_event_2.event_id]) + + def test_do_not_prune_gap_if_other_server(self): + """Test that we do not drop extremities after a gap when we see an event + from a different domain. + """ + + # Fudge a second event which points to an event we don't have. This is a + # state event so that the state changes (otherwise we won't prune the + # extremity as they'll have the same state group). 
+ remote_event_2 = event_from_pdu_json( + { + "type": EventTypes.Member, + "state_key": "@user:other2", + "content": {"membership": Membership.JOIN}, + "room_id": self.room_id, + "sender": "@user:other2", + "depth": 10, + "prev_events": ["$some_unknown_message"], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + + self.persist_event(remote_event_2, state=state_before_gap.values()) + + # Check the new extremity is just the new remote event. + self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id]) + + def test_prune_gap_if_dummy_remote(self): + """Test that we drop extremities after a gap when the previous extremity + is a local dummy event and only points to remote events. + """ + + body = self.helper.send_event( + self.room_id, type=EventTypes.Dummy, content={}, tok=self.token + ) + local_message_event_id = body["event_id"] + self.assert_extremities([local_message_event_id]) + + # Advance the clock for many days to make the old extremity "old". We + # also set the depth to "lots". + self.reactor.advance(7 * 24 * 60 * 60) + + # Fudge a second event which points to an event we don't have. This is a + # state event so that the state changes (otherwise we won't prune the + # extremity as they'll have the same state group). + remote_event_2 = event_from_pdu_json( + { + "type": EventTypes.Member, + "state_key": "@user:other2", + "content": {"membership": Membership.JOIN}, + "room_id": self.room_id, + "sender": "@user:other2", + "depth": 10000, + "prev_events": ["$some_unknown_message"], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + + self.persist_event(remote_event_2, state=state_before_gap.values()) + + # Check the new extremity is just the new remote event. + self.assert_extremities([remote_event_2.event_id]) + + def test_prune_gap_if_dummy_local(self): + """Test that we don't drop extremities after a gap when the previous + extremity is a local dummy event and points to local events. + """ + + body = self.helper.send(self.room_id, body="Test", tok=self.token) + + body = self.helper.send_event( + self.room_id, type=EventTypes.Dummy, content={}, tok=self.token + ) + local_message_event_id = body["event_id"] + self.assert_extremities([local_message_event_id]) + + # Advance the clock for many days to make the old extremity "old". We + # also set the depth to "lots". + self.reactor.advance(7 * 24 * 60 * 60) + + # Fudge a second event which points to an event we don't have. This is a + # state event so that the state changes (otherwise we won't prune the + # extremity as they'll have the same state group). + remote_event_2 = event_from_pdu_json( + { + "type": EventTypes.Member, + "state_key": "@user:other2", + "content": {"membership": Membership.JOIN}, + "room_id": self.room_id, + "sender": "@user:other2", + "depth": 10000, + "prev_events": ["$some_unknown_message"], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + + self.persist_event(remote_event_2, state=state_before_gap.values()) + + # Check the new extremity is just the new remote event. 
+ self.assert_extremities([remote_event_2.event_id, local_message_event_id]) + + def test_do_not_prune_gap_if_not_dummy(self): + """Test that we do not drop extremities after a gap when the previous extremity + is not a dummy event. + """ + + body = self.helper.send(self.room_id, body="test", tok=self.token) + local_message_event_id = body["event_id"] + self.assert_extremities([local_message_event_id]) + + # Fudge a second event which points to an event we don't have. This is a + # state event so that the state changes (otherwise we won't prune the + # extremity as they'll have the same state group). + remote_event_2 = event_from_pdu_json( + { + "type": EventTypes.Member, + "state_key": "@user:other2", + "content": {"membership": Membership.JOIN}, + "room_id": self.room_id, + "sender": "@user:other2", + "depth": 10000, + "prev_events": ["$some_unknown_message"], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + + self.persist_event(remote_event_2, state=state_before_gap.values()) + + # Check that the old local extremity is retained alongside the new remote event. + self.assert_extremities([local_message_event_id, remote_event_2.event_id]) diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 392b08832b..cc0612cf65 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py
@@ -199,10 +199,17 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): first_id_gen = self._create_id_generator("first", writers=["first", "second"]) second_id_gen = self._create_id_generator("second", writers=["first", "second"]) - self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) - self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3) + # The first ID gen will notice that it can advance its token to 7 as it + # has no in progress writes... + self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 7}) + self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7) self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7) + # ... but the second ID gen doesn't know that. + self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) + self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3) + self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7) + # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. @@ -211,7 +218,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(stream_id, 8) self.assertEqual( - first_id_gen.get_positions(), {"first": 3, "second": 7} + first_id_gen.get_positions(), {"first": 7, "second": 7} ) self.get_success(_get_next_async()) @@ -279,7 +286,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self._insert_row_with_id("first", 3) self._insert_row_with_id("second", 5) - id_gen = self._create_id_generator("first", writers=["first", "second"]) + id_gen = self._create_id_generator("worker", writers=["first", "second"]) self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) @@ -319,14 +326,14 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): id_gen = self._create_id_generator("first", writers=["first", "second"]) - self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) + self.assertEqual(id_gen.get_positions(), {"first": 5, "second": 5}) - self.assertEqual(id_gen.get_persisted_upto_position(), 3) + self.assertEqual(id_gen.get_persisted_upto_position(), 5) async def _get_next_async(): async with id_gen.get_next() as stream_id: self.assertEqual(stream_id, 6) - self.assertEqual(id_gen.get_persisted_upto_position(), 3) + self.assertEqual(id_gen.get_persisted_upto_position(), 5) self.get_success(_get_next_async()) @@ -388,7 +395,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self._insert_row_with_id("second", 5) # Initial config has two writers - id_gen = self._create_id_generator("first", writers=["first", "second"]) + id_gen = self._create_id_generator("worker", writers=["first", "second"]) self.assertEqual(id_gen.get_persisted_upto_position(), 3) self.assertEqual(id_gen.get_current_token_for_writer("first"), 3) self.assertEqual(id_gen.get_current_token_for_writer("second"), 5) @@ -568,7 +575,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): self.get_success(_get_next_async2()) - self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2}) + self.assertEqual(id_gen_1.get_positions(), {"first": -2, "second": -2}) self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2}) self.assertEqual(id_gen_1.get_persisted_upto_position(), -2) self.assertEqual(id_gen_2.get_persisted_upto_position(), -2) diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
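A minimal sketch of the MultiWriterIdGenerator behaviour the hunks above now assert, reusing names from the surrounding test case (_create_id_generator, get_success and get_positions are assumed to be available exactly as they are there): a position only becomes visible to other writers once the get_next() context manager exits, while a writer with no in-flight writes is free to advance its own view to the latest persisted row.

    first_id_gen = self._create_id_generator("first", writers=["first", "second"])

    async def _write():
        async with first_id_gen.get_next() as stream_id:
            # while the write is in flight, peers still report the old position
            pass
        # once the context manager exits, the row counts as persisted and
        # get_positions() / get_current_token_for_writer("first") advance

    self.get_success(_write())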
index 7e7f1286d9..e9e3bca3bf 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py
@@ -48,3 +48,10 @@ class DataStoreTestCase(unittest.TestCase): self.assertEquals(1, total) self.assertEquals(self.displayname, users.pop()["displayname"]) + + users, total = yield defer.ensureDeferred( + self.store.get_users_paginate(0, 10, name="BC", guests=False) + ) + + self.assertEquals(1, total) + self.assertEquals(self.displayname, users.pop()["displayname"]) diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index cc1f3c53c5..a06ad2c03e 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py
@@ -27,7 +27,7 @@ class PurgeTests(HomeserverTestCase): servlets = [room.register_servlets] def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver("server", http_client=None) + hs = self.setup_test_homeserver("server", federation_http_client=None) return hs def prepare(self, reactor, clock, hs): diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 1ea35d60c1..a6303bf0ee 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py
@@ -14,9 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from mock import Mock - from canonicaljson import json from twisted.internet import defer @@ -30,12 +27,10 @@ from tests.utils import create_room class RedactionTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): - config = self.default_config() + def default_config(self): + config = super().default_config() config["redaction_retention_period"] = "30d" - return self.setup_test_homeserver( - resource_for_federation=Mock(), http_client=None, config=config - ) + return config def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() @@ -236,9 +231,9 @@ class RedactionTestCase(unittest.HomeserverTestCase): self._event_id = event_id @defer.inlineCallbacks - def build(self, prev_event_ids): + def build(self, prev_event_ids, auth_event_ids): built_event = yield defer.ensureDeferred( - self._base_builder.build(prev_event_ids) + self._base_builder.build(prev_event_ids, auth_event_ids) ) built_event._event_id = self._event_id diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 6b582771fe..c8c7a90e5d 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py
@@ -69,11 +69,9 @@ class RegistrationStoreTestCase(unittest.TestCase): self.store.get_user_by_access_token(self.tokens[1]) ) - self.assertDictContainsSubset( - {"name": self.user_id, "device_id": self.device_id}, result - ) - - self.assertTrue("token_id" in result) + self.assertEqual(result.user_id, self.user_id) + self.assertEqual(result.device_id, self.device_id) + self.assertIsNotNone(result.token_id) @defer.inlineCallbacks def test_user_delete_access_tokens(self): @@ -105,7 +103,7 @@ class RegistrationStoreTestCase(unittest.TestCase): user = yield defer.ensureDeferred( self.store.get_user_by_access_token(self.tokens[0]) ) - self.assertEqual(self.user_id, user["name"]) + self.assertEqual(self.user_id, user.user_id) # now delete the rest yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id)) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 12ccc1f53e..d2aed66f6d 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py
@@ -14,12 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock - from synapse.api.constants import Membership from synapse.rest.admin import register_servlets_for_client_rest_resource from synapse.rest.client.v1 import login, room -from synapse.types import Requester, UserID +from synapse.types import UserID, create_requester from tests import unittest from tests.test_utils import event_injection @@ -34,12 +32,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver( - resource_for_federation=Mock(), http_client=None - ) - return hs - def prepare(self, reactor, clock, hs: TestHomeServer): # We can't test the RoomMemberStore on its own without the other event @@ -187,7 +179,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): # Now let's create a room, which will insert a membership user = UserID("alice", "test") - requester = Requester(user, None, False, False, None, None) + requester = create_requester(user) self.get_success(self.room_creator.create_room(requester, {})) # Register the background update to run again. diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 738e912468..a6f63f4aaf 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py
@@ -21,6 +21,8 @@ from tests.utils import setup_test_homeserver ALICE = "@alice:a" BOB = "@bob:b" BOBBY = "@bobby:a" +# The localpart isn't 'Bela' on purpose so we can test looking up display names. +BELA = "@somenickname:a" class UserDirectoryStoreTestCase(unittest.TestCase): @@ -41,6 +43,9 @@ class UserDirectoryStoreTestCase(unittest.TestCase): self.store.update_profile_in_user_dir(BOBBY, "bobby", None) ) yield defer.ensureDeferred( + self.store.update_profile_in_user_dir(BELA, "Bela", None) + ) + yield defer.ensureDeferred( self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)) ) @@ -72,3 +77,21 @@ class UserDirectoryStoreTestCase(unittest.TestCase): ) finally: self.hs.config.user_directory_search_all_users = False + + @defer.inlineCallbacks + def test_search_user_dir_stop_words(self): + """Tests that a user can look up another user by searching for the start if its + display name even if that name happens to be a common English word that would + usually be ignored in full text searches. + """ + self.hs.config.user_directory_search_all_users = True + try: + r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "be", 10)) + self.assertFalse(r["limited"]) + self.assertEqual(1, len(r["results"])) + self.assertDictEqual( + r["results"][0], + {"user_id": BELA, "display_name": "Bela", "avatar_url": None}, + ) + finally: + self.hs.config.user_directory_search_all_users = False diff --git a/tests/test_federation.py b/tests/test_federation.py
index 27a7fc9ed7..fc9aab32d0 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py
@@ -20,7 +20,7 @@ from twisted.internet.defer import succeed from synapse.api.errors import FederationError from synapse.events import make_event_from_dict from synapse.logging.context import LoggingContext -from synapse.types import Requester, UserID +from synapse.types import UserID, create_requester from synapse.util import Clock from synapse.util.retryutils import NotRetryingDestination @@ -37,13 +37,13 @@ class MessageAcceptTests(unittest.HomeserverTestCase): self.hs_clock = Clock(self.reactor) self.homeserver = setup_test_homeserver( self.addCleanup, - http_client=self.http_client, + federation_http_client=self.http_client, clock=self.hs_clock, reactor=self.reactor, ) user_id = UserID("us", "test") - our_user = Requester(user_id, None, False, False, None, None) + our_user = create_requester(user_id) room_creator = self.homeserver.get_room_creation_handler() self.room_id = self.get_success( room_creator.create_room( @@ -75,7 +75,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): } ) - self.handler = self.homeserver.get_handlers().federation_handler + self.handler = self.homeserver.get_federation_handler() self.handler.do_auth = lambda origin, event, context, auth_events: succeed( context ) @@ -134,7 +134,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): } ) - with LoggingContext(request="lying_event"): + with LoggingContext(): failure = self.get_failure( self.handler.on_receive_pdu( "test.serv", lying_event, sent_to_us_directly=True diff --git a/tests/test_mau.py b/tests/test_mau.py
index 654a6fa42d..51660b51d5 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py
@@ -19,6 +19,7 @@ import json from synapse.api.constants import LoginType from synapse.api.errors import Codes, HttpResponseException, SynapseError +from synapse.appservice import ApplicationService from synapse.rest.client.v2_alpha import register, sync from tests import unittest @@ -75,6 +76,45 @@ class TestMauLimit(unittest.HomeserverTestCase): self.assertEqual(e.code, 403) self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + def test_as_ignores_mau(self): + """Test that application services can still create users when the MAU + limit has been reached. This only works when application service + user ip tracking is disabled. + """ + + # Create and sync so that the MAU counts get updated + token1 = self.create_user("kermit1") + self.do_sync_for_user(token1) + token2 = self.create_user("kermit2") + self.do_sync_for_user(token2) + + # check we're testing what we think we are: there should be two active users + self.assertEqual(self.get_success(self.store.get_monthly_active_count()), 2) + + # We've created and activated two users, we shouldn't be able to + # register new users + with self.assertRaises(SynapseError) as cm: + self.create_user("kermit3") + + e = cm.exception + self.assertEqual(e.code, 403) + self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + + # Cheekily add an application service that we use to register a new user + # with. + as_token = "foobartoken" + self.store.services_cache.append( + ApplicationService( + token=as_token, + hostname=self.hs.hostname, + id="SomeASID", + sender="@as_sender:test", + namespaces={"users": [{"regex": "@as_*", "exclusive": True}]}, + ) + ) + + self.create_user("as_kermit4", token=as_token) + def test_allowed_after_a_month_mau(self): # Create and sync so that the MAU counts get updated token1 = self.create_user("kermit1") @@ -192,7 +232,7 @@ class TestMauLimit(unittest.HomeserverTestCase): self.reactor.advance(100) self.assertEqual(2, self.successResultOf(count)) - def create_user(self, localpart): + def create_user(self, localpart, token=None): request_data = json.dumps( { "username": localpart, @@ -201,8 +241,9 @@ class TestMauLimit(unittest.HomeserverTestCase): } ) - request, channel = self.make_request("POST", "/register", request_data) - self.render(request) + channel = self.make_request( + "POST", "/register", request_data, access_token=token, + ) if channel.code != 200: raise HttpResponseException( @@ -214,8 +255,7 @@ class TestMauLimit(unittest.HomeserverTestCase): return access_token def do_sync_for_user(self, token): - request, channel = self.make_request("GET", "/sync", access_token=token) - self.render(request) + channel = self.make_request("GET", "/sync", access_token=token) if channel.code != 200: raise HttpResponseException( diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index f5f63d8ed6..759e4cd048 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py
@@ -15,7 +15,7 @@ # limitations under the License. from synapse.metrics import REGISTRY, InFlightGauge, generate_latest -from synapse.util.caches.descriptors import Cache +from synapse.util.caches.deferred_cache import DeferredCache from tests import unittest @@ -138,7 +138,7 @@ class CacheMetricsTests(unittest.HomeserverTestCase): Caches produce metrics reflecting their state when scraped. """ CACHE_NAME = "cache_metrics_test_fgjkbdfg" - cache = Cache(CACHE_NAME, max_entries=777) + cache = DeferredCache(CACHE_NAME, max_entries=777) items = { x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py
index 7657bddea5..e7aed092c2 100644 --- a/tests/test_phone_home.py +++ b/tests/test_phone_home.py
@@ -17,7 +17,7 @@ import resource import mock -from synapse.app.homeserver import phone_stats_home +from synapse.app.phone_stats_home import phone_stats_home from tests.unittest import HomeserverTestCase diff --git a/tests/test_preview.py b/tests/test_preview.py
index 7f67ee9e1f..a883d707df 100644 --- a/tests/test_preview.py +++ b/tests/test_preview.py
@@ -56,7 +56,7 @@ class PreviewTestCase(unittest.TestCase): desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) - self.assertEquals( + self.assertEqual( desc, "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" @@ -69,7 +69,7 @@ class PreviewTestCase(unittest.TestCase): desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500) - self.assertEquals( + self.assertEqual( desc, "Tromsø lies in Northern Norway. The municipality has a population of" " (2015) 72,066, but with an annual influx of students it has over 75,000" @@ -96,7 +96,7 @@ class PreviewTestCase(unittest.TestCase): desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) - self.assertEquals( + self.assertEqual( desc, "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" @@ -122,7 +122,7 @@ class PreviewTestCase(unittest.TestCase): ] desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) - self.assertEquals( + self.assertEqual( desc, "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" @@ -149,7 +149,7 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."}) + self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) def test_comment(self): html = """ @@ -164,7 +164,7 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."}) + self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) def test_comment2(self): html = """ @@ -182,7 +182,7 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals( + self.assertEqual( og, { "og:title": "Foo", @@ -203,7 +203,7 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."}) + self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) def test_missing_title(self): html = """ @@ -216,7 +216,7 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {"og:title": None, "og:description": "Some text."}) + self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) def test_h1_as_title(self): html = """ @@ -230,7 +230,7 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {"og:title": "Title", "og:description": "Some text."}) + self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."}) def test_missing_title_and_broken_h1(self): html = """ @@ -244,4 +244,9 @@ class PreviewUrlTestCase(unittest.TestCase): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {"og:title": None, "og:description": "Some text."}) + self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) + + def test_empty(self): + html = "" + og = decode_and_calc_og(html, "http://example.com/test.html") + self.assertEqual(og, {}) diff --git a/tests/test_server.py 
b/tests/test_server.py
index 655c918a15..815da18e65 100644 --- a/tests/test_server.py +++ b/tests/test_server.py
@@ -26,9 +26,9 @@ from synapse.util import Clock from tests import unittest from tests.server import ( + FakeSite, ThreadedMemoryReactorClock, make_request, - render, setup_test_homeserver, ) @@ -38,7 +38,10 @@ class JsonResourceTests(unittest.TestCase): self.reactor = ThreadedMemoryReactorClock() self.hs_clock = Clock(self.reactor) self.homeserver = setup_test_homeserver( - self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor + self.addCleanup, + federation_http_client=None, + clock=self.hs_clock, + reactor=self.reactor, ) def test_handler_for_request(self): @@ -61,12 +64,10 @@ class JsonResourceTests(unittest.TestCase): "test_servlet", ) - request, channel = make_request( - self.reactor, b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83" + make_request( + self.reactor, FakeSite(res), b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83" ) - render(request, res, self.reactor) - self.assertEqual(request.args, {b"a": ["\N{SNOWMAN}".encode("utf8")]}) self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"}) def test_callback_direct_exception(self): @@ -83,8 +84,7 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") - render(request, res, self.reactor) + channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") self.assertEqual(channel.result["code"], b"500") @@ -108,8 +108,7 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") - render(request, res, self.reactor) + channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") self.assertEqual(channel.result["code"], b"500") @@ -127,8 +126,7 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") - render(request, res, self.reactor) + channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.json_body["error"], "Forbidden!!one!") @@ -150,8 +148,7 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - request, channel = make_request(self.reactor, b"GET", b"/_matrix/foobar") - render(request, res, self.reactor) + channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar") self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.json_body["error"], "Unrecognized request") @@ -173,8 +170,7 @@ class JsonResourceTests(unittest.TestCase): ) # The path was registered as GET, but this is a HEAD request. - request, channel = make_request(self.reactor, b"HEAD", b"/_matrix/foo") - render(request, res, self.reactor) + channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo") self.assertEqual(channel.result["code"], b"200") self.assertNotIn("body", channel.result) @@ -196,9 +192,6 @@ class OptionsResourceTests(unittest.TestCase): def _make_request(self, method, path): """Create a request from the method/path and return a channel with the response.""" - request, channel = make_request(self.reactor, method, path, shorthand=False) - request.prepath = [] # This doesn't get set properly by make_request. - # Create a site and query for the resource. 
site = SynapseSite( "test", @@ -207,11 +200,9 @@ class OptionsResourceTests(unittest.TestCase): self.resource, "1.0", ) - request.site = site - resource = site.getResourceFor(request) - # Finally, render the resource and return the channel. - render(request, resource, self.reactor) + # render the request and return the channel + channel = make_request(self.reactor, site, method, path, shorthand=False) return channel def test_unknown_options_request(self): @@ -284,8 +275,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - request, channel = make_request(self.reactor, b"GET", b"/path") - render(request, res, self.reactor) + channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") self.assertEqual(channel.result["code"], b"200") body = channel.result["body"] @@ -303,8 +293,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - request, channel = make_request(self.reactor, b"GET", b"/path") - render(request, res, self.reactor) + channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") self.assertEqual(channel.result["code"], b"301") headers = channel.result["headers"] @@ -325,8 +314,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - request, channel = make_request(self.reactor, b"GET", b"/path") - render(request, res, self.reactor) + channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") self.assertEqual(channel.result["code"], b"304") headers = channel.result["headers"] @@ -345,8 +333,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - request, channel = make_request(self.reactor, b"HEAD", b"/path") - render(request, res, self.reactor) + channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path") self.assertEqual(channel.result["code"], b"200") self.assertNotIn("body", channel.result) diff --git a/tests/test_state.py b/tests/test_state.py
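In tests/test_server.py the standalone make_request helper from tests.server now takes the site to render against, with bare resources wrapped in the newly imported FakeSite; a minimal sketch of the new convention (res stands for whichever resource the individual test registered):

    # before: make_request returned an unrendered request plus a channel
    request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
    render(request, res, self.reactor)

    # after: pass the resource via FakeSite and get back the finished channel
    channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
    self.assertEqual(channel.result["code"], b"200")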
index 80b0ccbc40..6227a3ba95 100644 --- a/tests/test_state.py +++ b/tests/test_state.py
@@ -169,6 +169,7 @@ class StateTestCase(unittest.TestCase): "get_state_handler", "get_clock", "get_state_resolution_handler", + "hostname", ] ) hs.config = default_config("tesths", True) diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index b89798336c..a743cdc3a9 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py
@@ -53,8 +53,7 @@ class TermsTestCase(unittest.HomeserverTestCase): def test_ui_auth(self): # Do a UI auth request request_data = json.dumps({"username": "kermit", "password": "monkey"}) - request, channel = self.make_request(b"POST", self.url, request_data) - self.render(request) + channel = self.make_request(b"POST", self.url, request_data) self.assertEquals(channel.result["code"], b"401", channel.result) @@ -97,8 +96,7 @@ class TermsTestCase(unittest.HomeserverTestCase): self.registration_handler.check_username = Mock(return_value=True) - request, channel = self.make_request(b"POST", self.url, request_data) - self.render(request) + channel = self.make_request(b"POST", self.url, request_data) # We don't bother checking that the response is correct - we'll leave that to # other tests. We just want to make sure we're on the right path. @@ -115,8 +113,7 @@ class TermsTestCase(unittest.HomeserverTestCase): }, } ) - request, channel = self.make_request(b"POST", self.url, request_data) - self.render(request) + channel = self.make_request(b"POST", self.url, request_data) # We're interested in getting a response that looks like a successful # registration, not so much that the details are exactly what we want. diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index a298cc0fd3..43898d8142 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py
@@ -17,8 +17,17 @@ """ Utilities for running the unit tests """ +import sys +import warnings from asyncio import Future -from typing import Any, Awaitable, TypeVar +from typing import Any, Awaitable, Callable, TypeVar + +from mock import Mock + +import attr + +from twisted.python.failure import Failure +from twisted.web.client import ResponseDone TV = TypeVar("TV") @@ -48,3 +57,65 @@ def make_awaitable(result: Any) -> Awaitable[Any]: future = Future() # type: ignore future.set_result(result) return future + + +def setup_awaitable_errors() -> Callable[[], None]: + """ + Convert warnings from a non-awaited coroutines into errors. + """ + warnings.simplefilter("error", RuntimeWarning) + + # unraisablehook was added in Python 3.8. + if not hasattr(sys, "unraisablehook"): + return lambda: None + + # State shared between unraisablehook and check_for_unraisable_exceptions. + unraisable_exceptions = [] + orig_unraisablehook = sys.unraisablehook # type: ignore + + def unraisablehook(unraisable): + unraisable_exceptions.append(unraisable.exc_value) + + def cleanup(): + """ + A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions. + """ + sys.unraisablehook = orig_unraisablehook # type: ignore + if unraisable_exceptions: + raise unraisable_exceptions.pop() + + sys.unraisablehook = unraisablehook # type: ignore + + return cleanup + + +def simple_async_mock(return_value=None, raises=None) -> Mock: + # AsyncMock is not available in python3.5, this mimics part of its behaviour + async def cb(*args, **kwargs): + if raises: + raise raises + return return_value + + return Mock(side_effect=cb) + + +@attr.s +class FakeResponse: + """A fake twisted.web.IResponse object + + there is a similar class at treq.test.test_response, but it lacks a `phrase` + attribute, and didn't support deliverBody until recently. + """ + + # HTTP response code + code = attr.ib(type=int) + + # HTTP response phrase (eg b'OK' for a 200) + phrase = attr.ib(type=bytes) + + # body of the response + body = attr.ib(type=bytes) + + def deliverBody(self, protocol): + protocol.dataReceived(self.body) + protocol.connectionLost(Failure(ResponseDone())) diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
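A short sketch of how the new tests.test_utils helpers above might be used; the awaited call sits inside some async test body, and body_protocol is an illustrative stand-in for whatever IProtocol consumer the caller supplies:

    # simple_async_mock: a Mock whose calls return an awaitable result
    fetcher = simple_async_mock(return_value={"ok": True})
    result = await fetcher("some-arg")          # -> {"ok": True}
    fetcher.assert_called_once_with("some-arg")

    # FakeResponse: just enough of an IResponse to hand a canned body to a consumer
    response = FakeResponse(code=200, phrase=b"OK", body=b'{"ok": true}')
    response.deliverBody(body_protocol)         # dataReceived(body), then connectionLost(ResponseDone)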
index e93aa84405..c3c4a93e1f 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py
@@ -50,7 +50,7 @@ async def inject_member_event( sender=sender, state_key=target, content=content, - **kwargs + **kwargs, ) diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index fdfb840b62..52ae5c5713 100644 --- a/tests/test_utils/logging_setup.py +++ b/tests/test_utils/logging_setup.py
@@ -48,7 +48,7 @@ def setup_logging(): handler = ToTwistedHandler() formatter = logging.Formatter(log_format) handler.setFormatter(formatter) - handler.addFilter(LoggingContextFilter(request="")) + handler.addFilter(LoggingContextFilter()) root_logger.addHandler(handler) log_level = os.environ.get("SYNAPSE_TEST_LOG_LEVEL", "ERROR") diff --git a/tests/unittest.py b/tests/unittest.py
index e654c0442d..af7f752c5a 100644 --- a/tests/unittest.py +++ b/tests/unittest.py
@@ -20,7 +20,7 @@ import hmac import inspect import logging import time -from typing import Optional, Tuple, Type, TypeVar, Union +from typing import Dict, Iterable, Optional, Tuple, Type, TypeVar, Union from mock import Mock, patch @@ -30,6 +30,7 @@ from twisted.internet.defer import Deferred, ensureDeferred, succeed from twisted.python.failure import Failure from twisted.python.threadpool import ThreadPool from twisted.trial import unittest +from twisted.web.resource import Resource from synapse.api.constants import EventTypes, Membership from synapse.config.homeserver import HomeServerConfig @@ -44,17 +45,12 @@ from synapse.logging.context import ( set_current_context, ) from synapse.server import HomeServer -from synapse.types import Requester, UserID, create_requester +from synapse.types import UserID, create_requester +from synapse.util.httpresourcetree import create_resource_tree from synapse.util.ratelimitutils import FederationRateLimiter -from tests.server import ( - FakeChannel, - get_clock, - make_request, - render, - setup_test_homeserver, -) -from tests.test_utils import event_injection +from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver +from tests.test_utils import event_injection, setup_awaitable_errors from tests.test_utils.logging_setup import setup_logging from tests.utils import default_config, setupdb @@ -119,6 +115,10 @@ class TestCase(unittest.TestCase): logging.getLogger().setLevel(level) + # Trial messes with the warnings configuration, thus this has to be + # done in the context of an individual TestCase. + self.addCleanup(setup_awaitable_errors()) + return orig() @around(self) @@ -235,13 +235,11 @@ class HomeserverTestCase(TestCase): if not isinstance(self.hs, HomeServer): raise Exception("A homeserver wasn't returned, but %r" % (self.hs,)) - # Register the resources - self.resource = self.create_test_json_resource() - - # create a site to wrap the resource. + # create the root resource, and a site to wrap it. + self.resource = self.create_test_resource() self.site = SynapseSite( logger_name="synapse.access.http.fake", - site_tag="test", + site_tag=self.hs.config.server.server_name, config=self.hs.config.server.listeners[0], resource=self.resource, server_version_string="1", @@ -249,22 +247,29 @@ class HomeserverTestCase(TestCase): from tests.rest.client.v1.utils import RestHelper - self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None)) + self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None)) if hasattr(self, "user_id"): if self.hijack_auth: + # We need a valid token ID to satisfy foreign key constraints. + token_id = self.get_success( + self.hs.get_datastore().add_access_token_to_user( + self.helper.auth_user_id, "some_fake_token", None, None, + ) + ) + async def get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.helper.auth_user_id), - "token_id": 1, + "token_id": token_id, "is_guest": False, } async def get_user_by_req(request, allow_guest=False, rights="access"): return create_requester( UserID.from_string(self.helper.auth_user_id), - 1, + token_id, False, False, None, @@ -312,22 +317,32 @@ class HomeserverTestCase(TestCase): hs = self.setup_test_homeserver() return hs - def create_test_json_resource(self): + def create_test_resource(self) -> Resource: """ - Create a test JsonResource, with the relevant servlets registerd to it + Create a the root resource for the test server. 
- The default implementation calls each function in `servlets` to do the - registration. - - Returns: - JsonResource: + The default calls `self.create_resource_dict` and builds the resultant dict + into a tree. """ - resource = JsonResource(self.hs) + root_resource = Resource() + create_resource_tree(self.create_resource_dict(), root_resource) + return root_resource - for servlet in self.servlets: - servlet(self.hs, resource) + def create_resource_dict(self) -> Dict[str, Resource]: + """Create a resource tree for the test server + + A resource tree is a mapping from path to twisted.web.resource. - return resource + The default implementation creates a JsonResource and calls each function in + `servlets` to register servlets against it. + """ + servlet_resource = JsonResource(self.hs) + for servlet in self.servlets: + servlet(self.hs, servlet_resource) + return { + "/_matrix/client": servlet_resource, + "/_synapse/admin": servlet_resource, + } def default_config(self): """ @@ -367,7 +382,11 @@ class HomeserverTestCase(TestCase): shorthand: bool = True, federation_auth_origin: str = None, content_is_form: bool = False, - ) -> Tuple[T, FakeChannel]: + await_result: bool = True, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, + ) -> FakeChannel: """ Create a SynapseRequest at the path using the method and containing the given content. @@ -385,14 +404,18 @@ class HomeserverTestCase(TestCase): content_is_form: Whether the content is URL encoded form data. Adds the 'Content-Type': 'application/x-www-form-urlencoded' header. + await_result: whether to wait for the request to complete rendering. If + true (the default), will pump the test reactor until the the renderer + tells the channel the request is finished. + + custom_headers: (name, value) pairs to add as request headers + Returns: - Tuple[synapse.http.site.SynapseRequest, channel] + The FakeChannel object which stores the result of the request. """ - if isinstance(content, dict): - content = json.dumps(content).encode("utf8") - return make_request( self.reactor, + self.site, method, path, content, @@ -401,18 +424,10 @@ class HomeserverTestCase(TestCase): shorthand, federation_auth_origin, content_is_form, + await_result, + custom_headers, ) - def render(self, request): - """ - Render a request against the resources registered by the test class's - servlets. - - Args: - request (synapse.http.site.SynapseRequest): The request to render. - """ - render(request, self.resource, self.reactor) - def setup_test_homeserver(self, *args, **kwargs): """ Set up the test homeserver, meant to be called by the overridable @@ -505,24 +520,29 @@ class HomeserverTestCase(TestCase): return result - def register_user(self, username, password, admin=False): + def register_user( + self, + username: str, + password: str, + admin: Optional[bool] = False, + displayname: Optional[str] = None, + ) -> str: """ Register a user. Requires the Admin API be registered. Args: - username (bytes/unicode): The user part of the new user. - password (bytes/unicode): The password of the new user. - admin (bool): Whether the user should be created as an admin - or not. + username: The user part of the new user. + password: The password of the new user. + admin: Whether the user should be created as an admin or not. + displayname: The displayname of the new user. Returns: - The MXID of the new user (unicode). + The MXID of the new user. 
""" self.hs.config.registration_shared_secret = "shared" # Create the user - request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register") - self.render(request) + channel = self.make_request("GET", "/_synapse/admin/v1/register") self.assertEqual(channel.code, 200, msg=channel.result) nonce = channel.json_body["nonce"] @@ -540,16 +560,16 @@ class HomeserverTestCase(TestCase): { "nonce": nonce, "username": username, + "displayname": displayname, "password": password, "admin": admin, "mac": want_mac, "inhibit_login": True, } ) - request, channel = self.make_request( - "POST", "/_matrix/client/r0/admin/register", body.encode("utf8") + channel = self.make_request( + "POST", "/_synapse/admin/v1/register", body.encode("utf8") ) - self.render(request) self.assertEqual(channel.code, 200, channel.json_body) user_id = channel.json_body["user_id"] @@ -565,10 +585,9 @@ class HomeserverTestCase(TestCase): if device_id: body["device_id"] = device_id - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") ) - self.render(request) self.assertEqual(channel.code, 200, channel.result) access_token = channel.json_body["access_token"] @@ -590,7 +609,7 @@ class HomeserverTestCase(TestCase): """ event_creator = self.hs.get_event_creation_handler() secrets = self.hs.get_secrets() - requester = Requester(user, None, False, False, None, None) + requester = create_requester(user) event, context = self.get_success( event_creator.create_event( @@ -608,7 +627,9 @@ class HomeserverTestCase(TestCase): if soft_failed: event.internal_metadata.soft_failed = True - self.get_success(event_creator.send_nonmember_event(requester, event, context)) + self.get_success( + event_creator.handle_new_client_event(requester, event, context) + ) return event.event_id @@ -632,10 +653,9 @@ class HomeserverTestCase(TestCase): """ body = {"type": "m.login.password", "user": username, "password": password} - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") ) - self.render(request) self.assertEqual(channel.code, 403, channel.result) def inject_room_member(self, room: str, user: str, membership: Membership) -> None: @@ -659,13 +679,29 @@ class FederatingHomeserverTestCase(HomeserverTestCase): A federating homeserver that authenticates incoming requests as `other.example.com`. """ - def prepare(self, reactor, clock, homeserver): + def create_resource_dict(self) -> Dict[str, Resource]: + d = super().create_resource_dict() + d["/_matrix/federation"] = TestTransportLayerServer(self.hs) + return d + + +class TestTransportLayerServer(JsonResource): + """A test implementation of TransportLayerServer + + authenticates incoming requests as `other.example.com`. 
+ """ + + def __init__(self, hs): + super().__init__(hs) + class Authenticator: def authenticate_request(self, request, content): return succeed("other.example.com") + authenticator = Authenticator() + ratelimiter = FederationRateLimiter( - clock, + hs.get_clock(), FederationRateLimitConfig( window_size=1, sleep_limit=1, @@ -674,11 +710,8 @@ class FederatingHomeserverTestCase(HomeserverTestCase): concurrent_requests=1000, ), ) - federation_server.register_servlets( - homeserver, self.resource, Authenticator(), ratelimiter - ) - return super().prepare(reactor, clock, homeserver) + federation_server.register_servlets(hs, self, authenticator, ratelimiter) def override_config(extra_config): diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py new file mode 100644
index 0000000000..dadfabd46d --- /dev/null +++ b/tests/util/caches/test_deferred_cache.py
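The new file below exercises synapse.util.caches.deferred_cache.DeferredCache directly; as orientation, a minimal sketch of the surface it covers (keys and values are illustrative, and the assertions assume a Twisted TestCase as in the file itself):

    cache = DeferredCache("example")

    cache.prefill("k1", 1)                                    # store a completed value
    self.assertEqual(self.successResultOf(cache.get("k1")), 1)

    d = defer.Deferred()
    cache.set("k2", d)                                        # store a pending result
    self.assertEqual(cache.get_immediate("k2", "default"), "default")
    d.callback(2)
    self.assertEqual(cache.get_immediate("k2", "default"), 2)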
@@ -0,0 +1,251 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial + +from twisted.internet import defer + +from synapse.util.caches.deferred_cache import DeferredCache + +from tests.unittest import TestCase + + +class DeferredCacheTestCase(TestCase): + def test_empty(self): + cache = DeferredCache("test") + failed = False + try: + cache.get("foo") + except KeyError: + failed = True + + self.assertTrue(failed) + + def test_hit(self): + cache = DeferredCache("test") + cache.prefill("foo", 123) + + self.assertEquals(self.successResultOf(cache.get("foo")), 123) + + def test_hit_deferred(self): + cache = DeferredCache("test") + origin_d = defer.Deferred() + set_d = cache.set("k1", origin_d) + + # get should return an incomplete deferred + get_d = cache.get("k1") + self.assertFalse(get_d.called) + + # add a callback that will make sure that the set_d gets called before the get_d + def check1(r): + self.assertTrue(set_d.called) + return r + + # TODO: Actually ObservableDeferred *doesn't* run its tests in order on py3.8. + # maybe we should fix that? + # get_d.addCallback(check1) + + # now fire off all the deferreds + origin_d.callback(99) + self.assertEqual(self.successResultOf(origin_d), 99) + self.assertEqual(self.successResultOf(set_d), 99) + self.assertEqual(self.successResultOf(get_d), 99) + + def test_callbacks(self): + """Invalidation callbacks are called at the right time""" + cache = DeferredCache("test") + callbacks = set() + + # start with an entry, with a callback + cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill")) + + # now replace that entry with a pending result + origin_d = defer.Deferred() + set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set")) + + # ... and also make a get request + get_d = cache.get("k1", callback=lambda: callbacks.add("get")) + + # we don't expect the invalidation callback for the original value to have + # been called yet, even though get() will now return a different result. + # I'm not sure if that is by design or not. + self.assertEqual(callbacks, set()) + + # now fire off all the deferreds + origin_d.callback(20) + self.assertEqual(self.successResultOf(set_d), 20) + self.assertEqual(self.successResultOf(get_d), 20) + + # now the original invalidation callback should have been called, but none of + # the others + self.assertEqual(callbacks, {"prefill"}) + callbacks.clear() + + # another update should invalidate both the previous results + cache.prefill("k1", 30) + self.assertEqual(callbacks, {"set", "get"}) + + def test_set_fail(self): + cache = DeferredCache("test") + callbacks = set() + + # start with an entry, with a callback + cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill")) + + # now replace that entry with a pending result + origin_d = defer.Deferred() + set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set")) + + # ... 
and also make a get request + get_d = cache.get("k1", callback=lambda: callbacks.add("get")) + + # none of the callbacks should have been called yet + self.assertEqual(callbacks, set()) + + # oh noes! fails! + e = Exception("oops") + origin_d.errback(e) + self.assertIs(self.failureResultOf(set_d, Exception).value, e) + self.assertIs(self.failureResultOf(get_d, Exception).value, e) + + # the callbacks for the failed requests should have been called. + # I'm not sure if this is deliberate or not. + self.assertEqual(callbacks, {"get", "set"}) + callbacks.clear() + + # the old value should still be returned now? + get_d2 = cache.get("k1", callback=lambda: callbacks.add("get2")) + self.assertEqual(self.successResultOf(get_d2), 10) + + # replacing the value now should run the callbacks for those requests + # which got the original result + cache.prefill("k1", 30) + self.assertEqual(callbacks, {"prefill", "get2"}) + + def test_get_immediate(self): + cache = DeferredCache("test") + d1 = defer.Deferred() + cache.set("key1", d1) + + # get_immediate should return default + v = cache.get_immediate("key1", 1) + self.assertEqual(v, 1) + + # now complete the set + d1.callback(2) + + # get_immediate should return result + v = cache.get_immediate("key1", 1) + self.assertEqual(v, 2) + + def test_invalidate(self): + cache = DeferredCache("test") + cache.prefill(("foo",), 123) + cache.invalidate(("foo",)) + + failed = False + try: + cache.get(("foo",)) + except KeyError: + failed = True + + self.assertTrue(failed) + + def test_invalidate_all(self): + cache = DeferredCache("testcache") + + callback_record = [False, False] + + def record_callback(idx): + callback_record[idx] = True + + # add a couple of pending entries + d1 = defer.Deferred() + cache.set("key1", d1, partial(record_callback, 0)) + + d2 = defer.Deferred() + cache.set("key2", d2, partial(record_callback, 1)) + + # lookup should return pending deferreds + self.assertFalse(cache.get("key1").called) + self.assertFalse(cache.get("key2").called) + + # let one of the lookups complete + d2.callback("result2") + + # now the cache will return a completed deferred + self.assertEqual(self.successResultOf(cache.get("key2")), "result2") + + # now do the invalidation + cache.invalidate_all() + + # lookup should fail + with self.assertRaises(KeyError): + cache.get("key1") + with self.assertRaises(KeyError): + cache.get("key2") + + # both callbacks should have been callbacked + self.assertTrue(callback_record[0], "Invalidation callback for key1 not called") + self.assertTrue(callback_record[1], "Invalidation callback for key2 not called") + + # letting the other lookup complete should do nothing + d1.callback("result1") + with self.assertRaises(KeyError): + cache.get("key1", None) + + def test_eviction(self): + cache = DeferredCache( + "test", max_entries=2, apply_cache_factor_from_config=False + ) + + cache.prefill(1, "one") + cache.prefill(2, "two") + cache.prefill(3, "three") # 1 will be evicted + + failed = False + try: + cache.get(1) + except KeyError: + failed = True + + self.assertTrue(failed) + + cache.get(2) + cache.get(3) + + def test_eviction_lru(self): + cache = DeferredCache( + "test", max_entries=2, apply_cache_factor_from_config=False + ) + + cache.prefill(1, "one") + cache.prefill(2, "two") + + # Now access 1 again, thus causing 2 to be least-recently used + cache.get(1) + + cache.prefill(3, "three") + + failed = False + try: + cache.get(2) + except KeyError: + failed = True + + self.assertTrue(failed) + + cache.get(1) + cache.get(3) diff --git 
a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 677e925477..cf1e3203a4 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py
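The hunks below move the old Cache tests onto DeferredCache and add coverage for a new synchronous @lru_cache descriptor alongside @cached; a minimal sketch of the usage pattern LruCacheDecoratorTestCase checks, with the class and its _compute helper purely illustrative:

    class ThingStore:
        def _compute(self, key, flavour):
            return (key, flavour)

        @lru_cache()
        def get_thing(self, key, flavour):
            # computed once per (key, flavour); repeat calls are served from the cache
            return self._compute(key, flavour)

    store = ThingStore()
    store.get_thing(1, 2)   # calls _compute
    store.get_thing(1, 2)   # cached; _compute is not called again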
@@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from functools import partial +from typing import Set import mock @@ -29,60 +29,50 @@ from synapse.logging.context import ( make_deferred_yieldable, ) from synapse.util.caches import descriptors -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import cached, lru_cache from tests import unittest +from tests.test_utils import get_awaitable_result logger = logging.getLogger(__name__) -def run_on_reactor(): - d = defer.Deferred() - reactor.callLater(0, d.callback, 0) - return make_deferred_yieldable(d) - - -class CacheTestCase(unittest.TestCase): - def test_invalidate_all(self): - cache = descriptors.Cache("testcache") - - callback_record = [False, False] - - def record_callback(idx): - callback_record[idx] = True - - # add a couple of pending entries - d1 = defer.Deferred() - cache.set("key1", d1, partial(record_callback, 0)) - - d2 = defer.Deferred() - cache.set("key2", d2, partial(record_callback, 1)) - - # lookup should return observable deferreds - self.assertFalse(cache.get("key1").has_called()) - self.assertFalse(cache.get("key2").has_called()) +class LruCacheDecoratorTestCase(unittest.TestCase): + def test_base(self): + class Cls: + def __init__(self): + self.mock = mock.Mock() - # let one of the lookups complete - d2.callback("result2") + @lru_cache() + def fn(self, arg1, arg2): + return self.mock(arg1, arg2) - # for now at least, the cache will return real results rather than an - # observabledeferred - self.assertEqual(cache.get("key2"), "result2") + obj = Cls() + obj.mock.return_value = "fish" + r = obj.fn(1, 2) + self.assertEqual(r, "fish") + obj.mock.assert_called_once_with(1, 2) + obj.mock.reset_mock() - # now do the invalidation - cache.invalidate_all() + # a call with different params should call the mock again + obj.mock.return_value = "chips" + r = obj.fn(1, 3) + self.assertEqual(r, "chips") + obj.mock.assert_called_once_with(1, 3) + obj.mock.reset_mock() - # lookup should return none - self.assertIsNone(cache.get("key1", None)) - self.assertIsNone(cache.get("key2", None)) + # the two values should now be cached + r = obj.fn(1, 2) + self.assertEqual(r, "fish") + r = obj.fn(1, 3) + self.assertEqual(r, "chips") + obj.mock.assert_not_called() - # both callbacks should have been callbacked - self.assertTrue(callback_record[0], "Invalidation callback for key1 not called") - self.assertTrue(callback_record[1], "Invalidation callback for key2 not called") - # letting the other lookup complete should do nothing - d1.callback("result1") - self.assertIsNone(cache.get("key1", None)) +def run_on_reactor(): + d = defer.Deferred() + reactor.callLater(0, d.callback, 0) + return make_deferred_yieldable(d) class DescriptorTestCase(unittest.TestCase): @@ -174,6 +164,57 @@ class DescriptorTestCase(unittest.TestCase): d = obj.fn(1) self.failureResultOf(d, SynapseError) + def test_cache_with_async_exception(self): + """The wrapped function returns a failure + """ + + class Cls: + result = None + call_count = 0 + + @cached() + def fn(self, arg1): + self.call_count += 1 + return self.result + + obj = Cls() + callbacks = set() # type: Set[str] + + # set off an asynchronous request + obj.result = origin_d = defer.Deferred() + + d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1")) + self.assertFalse(d1.called) + + # a second request should also return a deferred, but should not call the + # function itself. 
+ d2 = obj.fn(1, on_invalidate=lambda: callbacks.add("d2")) + self.assertFalse(d2.called) + self.assertEqual(obj.call_count, 1) + + # no callbacks yet + self.assertEqual(callbacks, set()) + + # the original request fails + e = Exception("bzz") + origin_d.errback(e) + + # ... which should cause the lookups to fail similarly + self.assertIs(self.failureResultOf(d1, Exception).value, e) + self.assertIs(self.failureResultOf(d2, Exception).value, e) + + # ... and the callbacks to have been, uh, called. + self.assertEqual(callbacks, {"d1", "d2"}) + + # ... leaving the cache empty + self.assertEqual(len(obj.fn.cache.cache), 0) + + # and a second call should work as normal + obj.result = defer.succeed(100) + d3 = obj.fn(1) + self.assertEqual(self.successResultOf(d3), 100) + self.assertEqual(obj.call_count, 2) + def test_cache_logcontexts(self): """Check that logcontexts are set and restored correctly when using the cache.""" @@ -354,6 +395,260 @@ class DescriptorTestCase(unittest.TestCase): d = obj.fn(1) self.failureResultOf(d, SynapseError) + def test_invalidate_cascade(self): + """Invalidations should cascade up through cache contexts""" + + class Cls: + @cached(cache_context=True) + async def func1(self, key, cache_context): + return await self.func2(key, on_invalidate=cache_context.invalidate) + + @cached(cache_context=True) + async def func2(self, key, cache_context): + return self.func3(key, on_invalidate=cache_context.invalidate) + + @lru_cache(cache_context=True) + def func3(self, key, cache_context): + self.invalidate = cache_context.invalidate + return 42 + + obj = Cls() + + top_invalidate = mock.Mock() + r = get_awaitable_result(obj.func1("k1", on_invalidate=top_invalidate)) + self.assertEqual(r, 42) + obj.invalidate() + top_invalidate.assert_called_once() + + +class CacheDecoratorTestCase(unittest.HomeserverTestCase): + """More tests for @cached + + The following is a set of tests that got lost in a different file for a while. + + There are probably duplicates of the tests in DescriptorTestCase. Ideally the + duplicates would be removed and the two sets of classes combined. 
+ """ + + @defer.inlineCallbacks + def test_passthrough(self): + class A: + @cached() + def func(self, key): + return key + + a = A() + + self.assertEquals((yield a.func("foo")), "foo") + self.assertEquals((yield a.func("bar")), "bar") + + @defer.inlineCallbacks + def test_hit(self): + callcount = [0] + + class A: + @cached() + def func(self, key): + callcount[0] += 1 + return key + + a = A() + yield a.func("foo") + + self.assertEquals(callcount[0], 1) + + self.assertEquals((yield a.func("foo")), "foo") + self.assertEquals(callcount[0], 1) + + @defer.inlineCallbacks + def test_invalidate(self): + callcount = [0] + + class A: + @cached() + def func(self, key): + callcount[0] += 1 + return key + + a = A() + yield a.func("foo") + + self.assertEquals(callcount[0], 1) + + a.func.invalidate(("foo",)) + + yield a.func("foo") + + self.assertEquals(callcount[0], 2) + + def test_invalidate_missing(self): + class A: + @cached() + def func(self, key): + return key + + A().func.invalidate(("what",)) + + @defer.inlineCallbacks + def test_max_entries(self): + callcount = [0] + + class A: + @cached(max_entries=10) + def func(self, key): + callcount[0] += 1 + return key + + a = A() + + for k in range(0, 12): + yield a.func(k) + + self.assertEquals(callcount[0], 12) + + # There must have been at least 2 evictions, meaning if we calculate + # all 12 values again, we must get called at least 2 more times + for k in range(0, 12): + yield a.func(k) + + self.assertTrue( + callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0]) + ) + + def test_prefill(self): + callcount = [0] + + d = defer.succeed(123) + + class A: + @cached() + def func(self, key): + callcount[0] += 1 + return d + + a = A() + + a.func.prefill(("foo",), 456) + + self.assertEquals(a.func("foo").result, 456) + self.assertEquals(callcount[0], 0) + + @defer.inlineCallbacks + def test_invalidate_context(self): + callcount = [0] + callcount2 = [0] + + class A: + @cached() + def func(self, key): + callcount[0] += 1 + return key + + @cached(cache_context=True) + def func2(self, key, cache_context): + callcount2[0] += 1 + return self.func(key, on_invalidate=cache_context.invalidate) + + a = A() + yield a.func2("foo") + + self.assertEquals(callcount[0], 1) + self.assertEquals(callcount2[0], 1) + + a.func.invalidate(("foo",)) + yield a.func("foo") + + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 1) + + yield a.func2("foo") + + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 2) + + @defer.inlineCallbacks + def test_eviction_context(self): + callcount = [0] + callcount2 = [0] + + class A: + @cached(max_entries=2) + def func(self, key): + callcount[0] += 1 + return key + + @cached(cache_context=True) + def func2(self, key, cache_context): + callcount2[0] += 1 + return self.func(key, on_invalidate=cache_context.invalidate) + + a = A() + yield a.func2("foo") + yield a.func2("foo2") + + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 2) + + yield a.func2("foo") + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 2) + + yield a.func("foo3") + + self.assertEquals(callcount[0], 3) + self.assertEquals(callcount2[0], 2) + + yield a.func2("foo") + + self.assertEquals(callcount[0], 4) + self.assertEquals(callcount2[0], 3) + + @defer.inlineCallbacks + def test_double_get(self): + callcount = [0] + callcount2 = [0] + + class A: + @cached() + def func(self, key): + callcount[0] += 1 + return key + + @cached(cache_context=True) + def func2(self, key, cache_context): 
+ callcount2[0] += 1 + return self.func(key, on_invalidate=cache_context.invalidate) + + a = A() + a.func2.cache.cache = mock.Mock(wraps=a.func2.cache.cache) + + yield a.func2("foo") + + self.assertEquals(callcount[0], 1) + self.assertEquals(callcount2[0], 1) + + a.func2.invalidate(("foo",)) + self.assertEquals(a.func2.cache.cache.pop.call_count, 1) + + yield a.func2("foo") + a.func2.invalidate(("foo",)) + self.assertEquals(a.func2.cache.cache.pop.call_count, 2) + + self.assertEquals(callcount[0], 1) + self.assertEquals(callcount2[0], 2) + + a.func.invalidate(("foo",)) + self.assertEquals(a.func2.cache.cache.pop.call_count, 3) + yield a.func("foo") + + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 2) + + yield a.func2("foo") + + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 3) + class CachedListDescriptorTestCase(unittest.TestCase): @defer.inlineCallbacks diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index 0adb2174af..a739a6aaaf 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py
@@ -19,7 +19,8 @@ from mock import Mock from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache -from .. import unittest +from tests import unittest +from tests.unittest import override_config class LruCacheTestCase(unittest.HomeserverTestCase): @@ -59,7 +60,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase): self.assertEquals(cache.pop("key"), None) def test_del_multi(self): - cache = LruCache(4, 2, cache_type=TreeCache) + cache = LruCache(4, keylen=2, cache_type=TreeCache) cache[("animal", "cat")] = "mew" cache[("animal", "dog")] = "woof" cache[("vehicles", "car")] = "vroom" @@ -83,6 +84,11 @@ class LruCacheTestCase(unittest.HomeserverTestCase): cache.clear() self.assertEquals(len(cache), 0) + @override_config({"caches": {"per_cache_factors": {"mycache": 10}}}) + def test_special_size(self): + cache = LruCache(10, "mycache") + self.assertEqual(cache.max_size, 100) + class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): def test_get(self): @@ -160,7 +166,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): m2 = Mock() m3 = Mock() m4 = Mock() - cache = LruCache(4, 2, cache_type=TreeCache) + cache = LruCache(4, keylen=2, cache_type=TreeCache) cache.set(("a", "1"), "value", callbacks=[m1]) cache.set(("a", "2"), "value", callbacks=[m2]) diff --git a/tests/utils.py b/tests/utils.py
index 4673872f88..977eeaf6ee 100644 --- a/tests/utils.py +++ b/tests/utils.py
@@ -20,12 +20,12 @@ import os import time import uuid import warnings -from inspect import getcallargs +from typing import Type from urllib import parse as urlparse from mock import Mock, patch -from twisted.internet import defer, reactor +from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.api.errors import CodeMessageException, cs_error @@ -33,14 +33,13 @@ from synapse.api.room_versions import RoomVersions from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION -from synapse.federation.transport import server as federation_server from synapse.http.server import HttpServer from synapse.logging.context import current_context, set_current_context from synapse.server import HomeServer from synapse.storage import DataStore +from synapse.storage.database import LoggingDatabaseConnection from synapse.storage.engines import PostgresEngine, create_engine from synapse.storage.prepare_database import prepare_database -from synapse.util.ratelimitutils import FederationRateLimiter # set this to True to run the tests against postgres instead of sqlite. # @@ -88,6 +87,7 @@ def setupdb(): host=POSTGRES_HOST, password=POSTGRES_PASSWORD, ) + db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests") prepare_database(db_conn, db_engine, None) db_conn.close() @@ -190,11 +190,10 @@ class TestHomeServer(HomeServer): def setup_test_homeserver( cleanup_func, name="test", - datastore=None, config=None, reactor=None, - homeserverToUse=TestHomeServer, - **kargs + homeserver_to_use: Type[HomeServer] = TestHomeServer, + **kwargs ): """ Setup a homeserver suitable for running tests against. Keyword arguments @@ -217,8 +216,8 @@ def setup_test_homeserver( config.ldap_enabled = False - if "clock" not in kargs: - kargs["clock"] = MockClock() + if "clock" not in kwargs: + kwargs["clock"] = MockClock() if USE_POSTGRES_FOR_TESTS: test_db = "synapse_test_%s" % uuid.uuid4().hex @@ -247,7 +246,7 @@ def setup_test_homeserver( # Create the database before we actually try and connect to it, based off # the template database we generate in setupdb() - if datastore is None and isinstance(db_engine, PostgresEngine): + if isinstance(db_engine, PostgresEngine): db_conn = db_engine.module.connect( database=POSTGRES_BASE_DB, user=POSTGRES_USER, @@ -263,79 +262,68 @@ def setup_test_homeserver( cur.close() db_conn.close() - if datastore is None: - hs = homeserverToUse( - name, - config=config, - version_string="Synapse/tests", - tls_server_context_factory=Mock(), - tls_client_options_factory=Mock(), - reactor=reactor, - **kargs - ) + hs = homeserver_to_use( + name, config=config, version_string="Synapse/tests", reactor=reactor, + ) - hs.setup() - if homeserverToUse.__name__ == "TestHomeServer": - hs.setup_master() - - if isinstance(db_engine, PostgresEngine): - database = hs.get_datastores().databases[0] - - # We need to do cleanup on PostgreSQL - def cleanup(): - import psycopg2 - - # Close all the db pools - database._db_pool.close() - - dropped = False - - # Drop the test database - db_conn = db_engine.module.connect( - database=POSTGRES_BASE_DB, - user=POSTGRES_USER, - host=POSTGRES_HOST, - password=POSTGRES_PASSWORD, - ) - db_conn.autocommit = True - cur = db_conn.cursor() - - # Try a few times to drop the DB. Some things may hold on to the - # database for a few more seconds due to flakiness, preventing - # us from dropping it when the test is over. 
If we can't drop - # it, warn and move on. - for x in range(5): - try: - cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) - db_conn.commit() - dropped = True - except psycopg2.OperationalError as e: - warnings.warn( - "Couldn't drop old db: " + str(e), category=UserWarning - ) - time.sleep(0.5) - - cur.close() - db_conn.close() - - if not dropped: - warnings.warn("Failed to drop old DB.", category=UserWarning) - - if not LEAVE_DB: - # Register the cleanup hook - cleanup_func(cleanup) + # Install @cache_in_self attributes + for key, val in kwargs.items(): + setattr(hs, "_" + key, val) - else: - hs = homeserverToUse( - name, - datastore=datastore, - config=config, - version_string="Synapse/tests", - tls_server_context_factory=Mock(), - tls_client_options_factory=Mock(), - reactor=reactor, - **kargs - ) + # Mock TLS + hs.tls_server_context_factory = Mock() + hs.tls_client_options_factory = Mock() + + hs.setup() + if homeserver_to_use == TestHomeServer: + hs.setup_background_tasks() + + if isinstance(db_engine, PostgresEngine): + database = hs.get_datastores().databases[0] + + # We need to do cleanup on PostgreSQL + def cleanup(): + import psycopg2 + + # Close all the db pools + database._db_pool.close() + + dropped = False + + # Drop the test database + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, + ) + db_conn.autocommit = True + cur = db_conn.cursor() + + # Try a few times to drop the DB. Some things may hold on to the + # database for a few more seconds due to flakiness, preventing + # us from dropping it when the test is over. If we can't drop + # it, warn and move on. + for x in range(5): + try: + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + db_conn.commit() + dropped = True + except psycopg2.OperationalError as e: + warnings.warn( + "Couldn't drop old db: " + str(e), category=UserWarning + ) + time.sleep(0.5) + + cur.close() + db_conn.close() + + if not dropped: + warnings.warn("Failed to drop old DB.", category=UserWarning) + + if not LEAVE_DB: + # Register the cleanup hook + cleanup_func(cleanup) # bcrypt is far too slow to be doing in unit tests # Need to let the HS build an auth handler and then mess with it @@ -351,32 +339,9 @@ def setup_test_homeserver( hs.get_auth_handler().validate_hash = validate_hash - fed = kargs.get("resource_for_federation", None) - if fed: - register_federation_servlets(hs, fed) - return hs -def register_federation_servlets(hs, resource): - federation_server.register_servlets( - hs, - resource=resource, - authenticator=federation_server.Authenticator(hs), - ratelimiter=FederationRateLimiter( - hs.get_clock(), config=hs.config.rc_federation - ), - ) - - -def get_mock_call_args(pattern_func, mock_func): - """ Return the arguments the mock function was called with interpreted - by the pattern functions argument list. - """ - invoked_args, invoked_kargs = mock_func.call_args - return getcallargs(pattern_func, *invoked_args, **invoked_kargs) - - def mock_getRawHeaders(headers=None): headers = headers if headers is not None else {} @@ -562,86 +527,6 @@ class MockClock: return d -def _format_call(args, kwargs): - return ", ".join( - ["%r" % (a) for a in args] + ["%s=%r" % (k, v) for k, v in kwargs.items()] - ) - - -class DeferredMockCallable: - """A callable instance that stores a set of pending call expectations and - return values for them. 
It allows a unit test to assert that the given set - of function calls are eventually made, by awaiting on them to be called. - """ - - def __init__(self): - self.expectations = [] - self.calls = [] - - def __call__(self, *args, **kwargs): - self.calls.append((args, kwargs)) - - if not self.expectations: - raise ValueError( - "%r has no pending calls to handle call(%s)" - % (self, _format_call(args, kwargs)) - ) - - for (call, result, d) in self.expectations: - if args == call[1] and kwargs == call[2]: - d.callback(None) - return result - - failure = AssertionError( - "Was not expecting call(%s)" % (_format_call(args, kwargs)) - ) - - for _, _, d in self.expectations: - try: - d.errback(failure) - except Exception: - pass - - raise failure - - def expect_call_and_return(self, call, result): - self.expectations.append((call, result, defer.Deferred())) - - @defer.inlineCallbacks - def await_calls(self, timeout=1000): - deferred = defer.DeferredList( - [d for _, _, d in self.expectations], fireOnOneErrback=True - ) - - timer = reactor.callLater( - timeout / 1000, - deferred.errback, - AssertionError( - "%d pending calls left: %s" - % ( - len([e for e in self.expectations if not e[2].called]), - [e for e in self.expectations if not e[2].called], - ) - ), - ) - - yield deferred - - timer.cancel() - - self.calls = [] - - def assert_had_no_calls(self): - if self.calls: - calls = self.calls - self.calls = [] - - raise AssertionError( - "Expected not to received any calls, got:\n" - + "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls]) - ) - - async def create_room(hs, room_id: str, creator_id: str): """Creates and persist a creation event for the given room """ diff --git a/tox.ini b/tox.ini
index 4d132eff4c..8e8b495292 100644 --- a/tox.ini +++ b/tox.ini
@@ -1,5 +1,5 @@ [tox] -envlist = packaging, py35, py36, py37, py38, check_codestyle, check_isort +envlist = packaging, py35, py36, py37, py38, py39, check_codestyle, check_isort [base] extras = test @@ -7,7 +7,9 @@ deps = python-subunit junitxml coverage - coverage-enable-subprocess + + # this is pinned since it's a bit of an obscure package. + coverage-enable-subprocess==1.0 # cyptography 2.2 requires setuptools >= 18.5 # @@ -23,34 +25,37 @@ deps = # install the "enum34" dependency of cryptography. pip>=10 -setenv = - # we have a pyproject.toml, but don't want pip to use it for building. - # (otherwise we get an error about 'editable mode is not supported for - # pyproject.toml-style projects'). - PIP_USE_PEP517 = false - - PYTHONDONTWRITEBYTECODE = no_byte_code - COVERAGE_PROCESS_START = {toxinidir}/.coveragerc - [testenv] deps = {[base]deps} extras = all, test -whitelist_externals = - sh - setenv = - {[base]setenv} + # use a postgres db for tox environments with "-postgres" in the name + # (see https://tox.readthedocs.io/en/3.20.1/config.html#factors-and-factor-conditional-settings) postgres: SYNAPSE_POSTGRES = 1 + + # this is used by .coveragerc to refer to the top of our tree. TOP={toxinidir} passenv = * commands = - /usr/bin/find "{toxinidir}" -name '*.pyc' -delete - # Add this so that coverage will run on subprocesses - {envbindir}/coverage run "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:} + # the "env" invocation enables coverage checking for sub-processes. This is + # particularly important when running trial with `-j`, since that will make + # it run tests in a subprocess, whose coverage would otherwise not be + # tracked. (It also makes an explicit `coverage run` command redundant.) + # + # (See https://coverage.readthedocs.io/en/coverage-5.3/subprocess.html. + # Note that the `coverage.process_startup()` call is done by + # `coverage-enable-subprocess`.) + # + # we use "env" rather than putting a value in `setenv` so that it is not + # inherited by other tox environments. + # + # keep this in sync with the copy in `testenv:py35-old`. + # + /usr/bin/env COVERAGE_PROCESS_START={toxinidir}/.coveragerc "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:} # As of twisted 16.4, trial tries to import the tests as a package (previously # it loaded the files explicitly), which means they need to be on the @@ -85,10 +90,9 @@ deps = lxml coverage - coverage-enable-subprocess + coverage-enable-subprocess==1.0 commands = - /usr/bin/find "{toxinidir}" -name '*.pyc' -delete # Make all greater-thans equals so we test the oldest version of our direct # dependencies, but make the pyopenssl 17.0, which can work against an # OpenSSL 1.1 compiled cryptography (as older ones don't compile on Travis). @@ -97,7 +101,11 @@ commands = # Install Synapse itself. This won't update any libraries. pip install -e ".[test]" - {envbindir}/coverage run "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:} + # we have to duplicate the command from `testenv` rather than refer to it + # as `{[testenv]commands}`, because we run on ubuntu xenial, which has + # tox 2.3.1, and https://github.com/tox-dev/tox/issues/208. 
+ # + /usr/bin/env COVERAGE_PROCESS_START={toxinidir}/.coveragerc "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:} [testenv:benchmark] deps = @@ -158,14 +166,7 @@ commands= coverage html [testenv:mypy] -skip_install = True deps = {[base]deps} - mypy==0.782 - mypy-zope -extras = all +extras = all,mypy commands = mypy - -# To find all folders that pass mypy you run: -# -# find synapse/* -type d -not -name __pycache__ -exec bash -c "mypy '{}' > /dev/null" \; -print
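
The descriptor-test hunks above exercise the `cached` and `lru_cache` decorators from synapse.util.caches.descriptors. As an illustrative sketch (not part of the patch; the class and method names here are invented), the basic usage pattern those tests cover looks like this:

    from synapse.util.caches.descriptors import cached, lru_cache

    class ExampleStore:
        # Illustrative class, not part of Synapse.

        @cached()
        async def get_thing(self, thing_id):
            # body runs only on a cache miss; the result is memoised per key
            return {"id": thing_id}

        @lru_cache()
        def get_rule(self, rule_id):
            # synchronous counterpart covered by the new
            # LruCacheDecoratorTestCase: returns a plain value, not an awaitable
            return ("rule", rule_id)

    # Callers can attach a per-lookup invalidation callback and evict or
    # pre-populate entries explicitly, as the tests above do:
    #
    #   await store.get_thing(42, on_invalidate=callback)
    #   store.get_thing.invalidate((42,))
    #   store.get_thing.prefill((42,), precomputed_value)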
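
test_invalidate_cascade ties the decorators together via cache_context: a cached method that depends on another cached lookup passes cache_context.invalidate as on_invalidate, so evicting the inner entry also evicts the outer one. A minimal sketch of the same wiring, with invented names:

    from synapse.util.caches.descriptors import cached, lru_cache

    class ChainedLookups:
        # Illustrative class, not part of Synapse.

        @cached(cache_context=True)
        async def outer(self, key, cache_context):
            # passing cache_context.invalidate links the two entries, so
            # invalidating inner(key) also evicts outer(key)
            return self.inner(key, on_invalidate=cache_context.invalidate)

        @lru_cache(cache_context=True)
        def inner(self, key, cache_context):
            return 42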
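
The test_lrucache.py hunks switch keylen to a keyword argument and add a per-cache-factor check. A sketch of the constructions they cover ("mycache" is simply the name used in the test, not a real Synapse cache):

    from synapse.util.caches.lrucache import LruCache
    from synapse.util.caches.treecache import TreeCache

    # tree-backed cache keyed on 2-tuples; keylen is now passed by keyword
    cache = LruCache(4, keylen=2, cache_type=TreeCache)
    cache[("animal", "cat")] = "mew"
    cache[("animal", "dog")] = "woof"

    # a named cache picks up caches.per_cache_factors from the homeserver
    # config; with the {"mycache": 10} override used in the test, its
    # max_size ends up as 10 * 10 == 100
    named_cache = LruCache(10, "mycache")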
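
tests/utils.py drops the datastore argument and the federation-servlet plumbing from setup_test_homeserver, renames homeserverToUse/kargs to homeserver_to_use/kwargs, and installs any extra keyword arguments on the HomeServer as the underscore-prefixed attributes used by @cache_in_self. A rough sketch of the new call shape; in practice this helper is invoked for you by the HomeserverTestCase machinery, so treat it as illustrative only:

    from tests.utils import MockClock, TestHomeServer, setup_test_homeserver

    cleanups = []
    hs = setup_test_homeserver(
        cleanups.append,                   # cleanup_func: collects teardown hooks
        name="test",
        homeserver_to_use=TestHomeServer,  # any HomeServer subclass
        clock=MockClock(),                 # extra kwargs land on the HomeServer,
    )                                      # e.g. this one becomes hs._clock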