summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/__init__.py7
-rw-r--r--synapse/_pydantic_compat.py64
-rwxr-xr-xsynapse/_scripts/generate_workers_map.py4
-rwxr-xr-xsynapse/_scripts/hash_password.py4
-rw-r--r--synapse/_scripts/review_recent_signups.py14
-rwxr-xr-xsynapse/_scripts/synapse_port_db.py68
-rwxr-xr-xsynapse/_scripts/synctl.py6
-rw-r--r--synapse/api/auth/__init__.py4
-rw-r--r--synapse/api/auth/msc3861_delegated.py291
-rw-r--r--synapse/api/auth_blocking.py20
-rw-r--r--synapse/api/constants.py36
-rw-r--r--synapse/api/errors.py27
-rw-r--r--synapse/api/ratelimiting.py49
-rw-r--r--synapse/api/room_versions.py95
-rw-r--r--synapse/api/urls.py45
-rw-r--r--synapse/app/generic_worker.py26
-rw-r--r--synapse/app/homeserver.py24
-rw-r--r--synapse/app/phone_stats_home.py40
-rw-r--r--synapse/appservice/__init__.py2
-rw-r--r--synapse/appservice/scheduler.py37
-rw-r--r--synapse/config/_base.py57
-rw-r--r--synapse/config/_base.pyi17
-rw-r--r--synapse/config/_util.py10
-rw-r--r--synapse/config/appservice.py13
-rw-r--r--synapse/config/captcha.py14
-rw-r--r--synapse/config/cas.py111
-rw-r--r--synapse/config/emailconfig.py366
-rw-r--r--synapse/config/experimental.py180
-rw-r--r--synapse/config/federation.py16
-rw-r--r--synapse/config/homeserver.py8
-rw-r--r--synapse/config/jwt.py2
-rw-r--r--synapse/config/key.py64
-rw-r--r--synapse/config/logger.py27
-rw-r--r--synapse/config/oidc.py59
-rw-r--r--synapse/config/push.py14
-rw-r--r--synapse/config/ratelimiting.py12
-rw-r--r--synapse/config/redis.py27
-rw-r--r--synapse/config/registration.py37
-rw-r--r--synapse/config/repository.py6
-rw-r--r--synapse/config/room_directory.py4
-rw-r--r--synapse/config/saml2.py248
-rw-r--r--synapse/config/server.py107
-rw-r--r--synapse/config/sso.py15
-rw-r--r--synapse/config/tls.py3
-rw-r--r--synapse/config/user_directory.py3
-rw-r--r--synapse/config/user_types.py44
-rw-r--r--synapse/config/voip.py23
-rw-r--r--synapse/config/workers.py58
-rw-r--r--synapse/crypto/keyring.py2
-rw-r--r--synapse/event_auth.py49
-rw-r--r--synapse/events/__init__.py28
-rw-r--r--synapse/events/auto_accept_invites.py85
-rw-r--r--synapse/events/builder.py55
-rw-r--r--synapse/events/presence_router.py2
-rw-r--r--synapse/events/snapshot.py16
-rw-r--r--synapse/events/utils.py5
-rw-r--r--synapse/events/validator.py14
-rw-r--r--synapse/federation/__init__.py3
-rw-r--r--synapse/federation/federation_base.py46
-rw-r--r--synapse/federation/federation_client.py70
-rw-r--r--synapse/federation/federation_server.py27
-rw-r--r--synapse/federation/persistence.py2
-rw-r--r--synapse/federation/sender/__init__.py229
-rw-r--r--synapse/federation/sender/per_destination_queue.py25
-rw-r--r--synapse/federation/transport/client.py27
-rw-r--r--synapse/federation/transport/server/__init__.py4
-rw-r--r--synapse/federation/transport/server/_base.py2
-rw-r--r--synapse/federation/transport/server/federation.py7
-rw-r--r--synapse/federation/units.py28
-rw-r--r--synapse/handlers/account.py8
-rw-r--r--synapse/handlers/account_data.py4
-rw-r--r--synapse/handlers/account_validity.py141
-rw-r--r--synapse/handlers/admin.py215
-rw-r--r--synapse/handlers/appservice.py8
-rw-r--r--synapse/handlers/auth.py177
-rw-r--r--synapse/handlers/cas.py412
-rw-r--r--synapse/handlers/deactivate_account.py39
-rw-r--r--synapse/handlers/delayed_events.py545
-rw-r--r--synapse/handlers/device.py355
-rw-r--r--synapse/handlers/directory.py18
-rw-r--r--synapse/handlers/e2e_keys.py97
-rw-r--r--synapse/handlers/e2e_room_keys.py4
-rw-r--r--synapse/handlers/federation.py69
-rw-r--r--synapse/handlers/federation_event.py23
-rw-r--r--synapse/handlers/identity.py811
-rw-r--r--synapse/handlers/jwt.py16
-rw-r--r--synapse/handlers/message.py86
-rw-r--r--synapse/handlers/oidc.py78
-rw-r--r--synapse/handlers/pagination.py63
-rw-r--r--synapse/handlers/presence.py13
-rw-r--r--synapse/handlers/profile.py191
-rw-r--r--synapse/handlers/register.py131
-rw-r--r--synapse/handlers/relations.py14
-rw-r--r--synapse/handlers/room.py65
-rw-r--r--synapse/handlers/room_member.py281
-rw-r--r--synapse/handlers/room_policy.py96
-rw-r--r--synapse/handlers/room_summary.py69
-rw-r--r--synapse/handlers/saml.py524
-rw-r--r--synapse/handlers/search.py6
-rw-r--r--synapse/handlers/send_email.py230
-rw-r--r--synapse/handlers/set_password.py18
-rw-r--r--synapse/handlers/sliding_sync.py3158
-rw-r--r--synapse/handlers/sliding_sync/__init__.py1691
-rw-r--r--synapse/handlers/sliding_sync/extensions.py879
-rw-r--r--synapse/handlers/sliding_sync/room_lists.py2304
-rw-r--r--synapse/handlers/sliding_sync/store.py128
-rw-r--r--synapse/handlers/sso.py44
-rw-r--r--synapse/handlers/sync.py311
-rw-r--r--synapse/handlers/ui_auth/checkers.py102
-rw-r--r--synapse/handlers/user_directory.py21
-rw-r--r--synapse/handlers/worker_lock.py11
-rw-r--r--synapse/http/client.py42
-rw-r--r--synapse/http/matrixfederationclient.py20
-rw-r--r--synapse/http/proxy.py40
-rw-r--r--synapse/http/proxyagent.py31
-rw-r--r--synapse/http/replicationagent.py4
-rw-r--r--synapse/http/server.py19
-rw-r--r--synapse/http/servlet.py25
-rw-r--r--synapse/http/site.py43
-rw-r--r--synapse/logging/_remote.py4
-rw-r--r--synapse/logging/_terse_json.py1
-rw-r--r--synapse/logging/context.py48
-rw-r--r--synapse/logging/filter.py3
-rw-r--r--synapse/logging/opentracing.py26
-rw-r--r--synapse/logging/scopecontextmanager.py19
-rw-r--r--synapse/media/_base.py261
-rw-r--r--synapse/media/media_repository.py70
-rw-r--r--synapse/media/media_storage.py103
-rw-r--r--synapse/media/storage_provider.py5
-rw-r--r--synapse/media/thumbnailer.py67
-rw-r--r--synapse/media/url_previewer.py17
-rw-r--r--synapse/metrics/__init__.py10
-rw-r--r--synapse/metrics/background_process_metrics.py2
-rw-r--r--synapse/metrics/jemalloc.py3
-rw-r--r--synapse/module_api/__init__.py181
-rw-r--r--synapse/module_api/callbacks/__init__.py8
-rw-r--r--synapse/module_api/callbacks/media_repository_callbacks.py76
-rw-r--r--synapse/module_api/callbacks/ratelimit_callbacks.py74
-rw-r--r--synapse/module_api/callbacks/spamchecker_callbacks.py196
-rw-r--r--synapse/module_api/callbacks/third_party_event_rules_callbacks.py145
-rw-r--r--synapse/notifier.py176
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py31
-rw-r--r--synapse/push/emailpusher.py331
-rw-r--r--synapse/push/httppusher.py22
-rw-r--r--synapse/push/mailer.py1003
-rw-r--r--synapse/push/push_tools.py8
-rw-r--r--synapse/push/push_types.py4
-rw-r--r--synapse/push/pusher.py38
-rw-r--r--synapse/push/pusherpool.py7
-rw-r--r--synapse/replication/http/__init__.py4
-rw-r--r--synapse/replication/http/_base.py6
-rw-r--r--synapse/replication/http/delayed_events.py62
-rw-r--r--synapse/replication/http/federation.py4
-rw-r--r--synapse/replication/http/push.py7
-rw-r--r--synapse/replication/tcp/client.py4
-rw-r--r--synapse/replication/tcp/commands.py3
-rw-r--r--synapse/replication/tcp/handler.py4
-rw-r--r--synapse/replication/tcp/protocol.py1
-rw-r--r--synapse/replication/tcp/resource.py5
-rw-r--r--synapse/replication/tcp/streams/_base.py2
-rw-r--r--synapse/replication/tcp/streams/events.py6
-rw-r--r--synapse/rest/__init__.py9
-rw-r--r--synapse/rest/admin/__init__.py52
-rw-r--r--synapse/rest/admin/devices.py37
-rw-r--r--synapse/rest/admin/event_reports.py9
-rw-r--r--synapse/rest/admin/experimental_features.py3
-rw-r--r--synapse/rest/admin/registration_tokens.py3
-rw-r--r--synapse/rest/admin/rooms.py20
-rw-r--r--synapse/rest/admin/scheduled_tasks.py70
-rw-r--r--synapse/rest/admin/users.py231
-rw-r--r--synapse/rest/client/_base.py4
-rw-r--r--synapse/rest/client/account.py612
-rw-r--r--synapse/rest/client/account_data.py6
-rw-r--r--synapse/rest/client/account_validity.py4
-rw-r--r--synapse/rest/client/appservice_ping.py7
-rw-r--r--synapse/rest/client/auth.py21
-rw-r--r--synapse/rest/client/auth_metadata.py (renamed from synapse/rest/client/auth_issuer.py)47
-rw-r--r--synapse/rest/client/capabilities.py26
-rw-r--r--synapse/rest/client/delayed_events.py111
-rw-r--r--synapse/rest/client/devices.py88
-rw-r--r--synapse/rest/client/directory.py12
-rw-r--r--synapse/rest/client/events.py1
-rw-r--r--synapse/rest/client/keys.py46
-rw-r--r--synapse/rest/client/knock.py13
-rw-r--r--synapse/rest/client/login.py67
-rw-r--r--synapse/rest/client/media.py15
-rw-r--r--synapse/rest/client/presence.py26
-rw-r--r--synapse/rest/client/profile.py204
-rw-r--r--synapse/rest/client/pusher.py16
-rw-r--r--synapse/rest/client/receipts.py4
-rw-r--r--synapse/rest/client/register.py410
-rw-r--r--synapse/rest/client/rendezvous.py54
-rw-r--r--synapse/rest/client/reporting.py29
-rw-r--r--synapse/rest/client/room.py144
-rw-r--r--synapse/rest/client/sync.py102
-rw-r--r--synapse/rest/client/tags.py7
-rw-r--r--synapse/rest/client/transactions.py7
-rw-r--r--synapse/rest/client/versions.py19
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py16
-rw-r--r--synapse/rest/media/config_resource.py11
-rw-r--r--synapse/rest/media/upload_resource.py26
-rw-r--r--synapse/rest/synapse/client/__init__.py15
-rw-r--r--synapse/rest/synapse/client/password_reset.py129
-rw-r--r--synapse/rest/synapse/client/pick_idp.py29
-rw-r--r--synapse/rest/synapse/client/saml2/__init__.py42
-rw-r--r--synapse/rest/synapse/client/saml2/metadata_resource.py46
-rw-r--r--synapse/rest/synapse/client/saml2/response_resource.py52
-rw-r--r--synapse/rest/synapse/client/unsubscribe.py88
-rw-r--r--synapse/rest/well_known.py46
-rw-r--r--synapse/server.py64
-rw-r--r--synapse/server_notices/resource_limits_server_notices.py4
-rw-r--r--synapse/state/__init__.py73
-rw-r--r--synapse/state/v2.py4
-rw-r--r--synapse/storage/_base.py22
-rw-r--r--synapse/storage/background_updates.py50
-rw-r--r--synapse/storage/controllers/persist_events.py52
-rw-r--r--synapse/storage/controllers/purge_events.py337
-rw-r--r--synapse/storage/controllers/state.py56
-rw-r--r--synapse/storage/database.py99
-rw-r--r--synapse/storage/databases/__init__.py10
-rw-r--r--synapse/storage/databases/main/__init__.py6
-rw-r--r--synapse/storage/databases/main/account_data.py88
-rw-r--r--synapse/storage/databases/main/cache.py72
-rw-r--r--synapse/storage/databases/main/client_ips.py33
-rw-r--r--synapse/storage/databases/main/delayed_events.py549
-rw-r--r--synapse/storage/databases/main/deviceinbox.py8
-rw-r--r--synapse/storage/databases/main/devices.py27
-rw-r--r--synapse/storage/databases/main/e2e_room_keys.py42
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py165
-rw-r--r--synapse/storage/databases/main/event_federation.py16
-rw-r--r--synapse/storage/databases/main/event_push_actions.py95
-rw-r--r--synapse/storage/databases/main/events.py1055
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py1425
-rw-r--r--synapse/storage/databases/main/events_worker.py248
-rw-r--r--synapse/storage/databases/main/media_repository.py110
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py93
-rw-r--r--synapse/storage/databases/main/profile.py346
-rw-r--r--synapse/storage/databases/main/purge_events.py50
-rw-r--r--synapse/storage/databases/main/push_rule.py1
-rw-r--r--synapse/storage/databases/main/receipts.py309
-rw-r--r--synapse/storage/databases/main/registration.py742
-rw-r--r--synapse/storage/databases/main/room.py345
-rw-r--r--synapse/storage/databases/main/roommember.py534
-rw-r--r--synapse/storage/databases/main/search.py8
-rw-r--r--synapse/storage/databases/main/sliding_sync.py603
-rw-r--r--synapse/storage/databases/main/state.py42
-rw-r--r--synapse/storage/databases/main/state_deltas.py176
-rw-r--r--synapse/storage/databases/main/stats.py8
-rw-r--r--synapse/storage/databases/main/stream.py603
-rw-r--r--synapse/storage/databases/main/tags.py70
-rw-r--r--synapse/storage/databases/main/transactions.py4
-rw-r--r--synapse/storage/databases/main/user_directory.py40
-rw-r--r--synapse/storage/databases/state/bg_updates.py17
-rw-r--r--synapse/storage/databases/state/deletion.py561
-rw-r--r--synapse/storage/databases/state/store.py172
-rw-r--r--synapse/storage/engines/_base.py5
-rw-r--r--synapse/storage/engines/postgres.py11
-rw-r--r--synapse/storage/engines/sqlite.py6
-rw-r--r--synapse/storage/invite_rule.py110
-rw-r--r--synapse/storage/prepare_database.py4
-rw-r--r--synapse/storage/roommember.py28
-rw-r--r--synapse/storage/schema/__init__.py30
-rw-r--r--synapse/storage/schema/main/delta/25/fts.py3
-rw-r--r--synapse/storage/schema/main/delta/27/ts.py3
-rw-r--r--synapse/storage/schema/main/delta/31/search_update.py3
-rw-r--r--synapse/storage/schema/main/delta/33/event_fields.py3
-rw-r--r--synapse/storage/schema/main/delta/56/unique_user_filter_index.py4
-rw-r--r--synapse/storage/schema/main/delta/61/03recreate_min_depth.py1
-rw-r--r--synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py1
-rw-r--r--synapse/storage/schema/main/delta/72/07force_update_current_state_events_membership.py1
-rw-r--r--synapse/storage/schema/main/delta/74/04_membership_tables_event_stream_ordering_triggers.py1
-rw-r--r--synapse/storage/schema/main/delta/78/03event_extremities_constraints.py1
-rw-r--r--synapse/storage/schema/main/delta/86/02_receipts_event_id_index.sql15
-rw-r--r--synapse/storage/schema/main/delta/87/01_sliding_sync_memberships.sql169
-rw-r--r--synapse/storage/schema/main/delta/87/02_per_connection_state.sql81
-rw-r--r--synapse/storage/schema/main/delta/87/03_current_state_index.sql19
-rw-r--r--synapse/storage/schema/main/delta/88/01_add_delayed_events.sql43
-rw-r--r--synapse/storage/schema/main/delta/88/01_custom_profile_fields.sql15
-rw-r--r--synapse/storage/schema/main/delta/88/02_fix_sliding_sync_membership_snapshots_forgotten_column.sql21
-rw-r--r--synapse/storage/schema/main/delta/88/03_add_otk_ts_added_index.sql18
-rw-r--r--synapse/storage/schema/main/delta/88/04_current_state_delta_index.sql18
-rw-r--r--synapse/storage/schema/main/delta/88/05_drop_old_otks.sql.postgres19
-rw-r--r--synapse/storage/schema/main/delta/88/05_drop_old_otks.sql.sqlite19
-rw-r--r--synapse/storage/schema/main/delta/88/05_sliding_sync_room_config_index.sql20
-rw-r--r--synapse/storage/schema/main/delta/88/06_events_received_ts_index.sql17
-rw-r--r--synapse/storage/schema/main/delta/89/01_sliding_sync_membership_snapshot_index.sql15
-rw-r--r--synapse/storage/schema/main/delta/90/01_add_column_participant_room_memberships_table.sql16
-rw-r--r--synapse/storage/schema/main/delta/91/01_media_hash.sql28
-rw-r--r--synapse/storage/schema/main/delta/92/01_remove_trigger.sql.postgres16
-rw-r--r--synapse/storage/schema/main/delta/92/01_remove_trigger.sql.sqlite16
-rw-r--r--synapse/storage/schema/main/delta/92/02_remove_populate_participant_bg_update.sql17
-rw-r--r--synapse/storage/schema/main/delta/92/04_ss_membership_snapshot_idx.sql16
-rw-r--r--synapse/storage/schema/main/delta/92/05_fixup_max_depth_cap.sql17
-rw-r--r--synapse/storage/schema/state/delta/89/01_state_groups_deletion.sql39
-rw-r--r--synapse/storage/schema/state/delta/90/02_delete_unreferenced_state_groups.sql16
-rw-r--r--synapse/storage/schema/state/delta/90/03_remove_old_deletion_bg_update.sql15
-rw-r--r--synapse/storage/types.py3
-rw-r--r--synapse/synapse_rust/events.pyi28
-rw-r--r--synapse/synapse_rust/push.pyi2
-rw-r--r--synapse/types/__init__.py13
-rw-r--r--synapse/types/handlers/__init__.py364
-rw-r--r--synapse/types/handlers/policy_server.py16
-rw-r--r--synapse/types/handlers/sliding_sync.py875
-rw-r--r--synapse/types/rest/__init__.py9
-rw-r--r--synapse/types/rest/client/__init__.py75
-rw-r--r--synapse/types/state.py39
-rw-r--r--synapse/types/storage/__init__.py56
-rw-r--r--synapse/util/async_helpers.py154
-rw-r--r--synapse/util/caches/dictionary_cache.py13
-rw-r--r--synapse/util/caches/expiringcache.py3
-rw-r--r--synapse/util/caches/lrucache.py3
-rw-r--r--synapse/util/caches/response_cache.py33
-rw-r--r--synapse/util/caches/stream_change_cache.py23
-rw-r--r--synapse/util/events.py29
-rw-r--r--synapse/util/iterutils.py7
-rw-r--r--synapse/util/linked_list.py3
-rw-r--r--synapse/util/macaroons.py3
-rw-r--r--synapse/util/metrics.py15
-rw-r--r--synapse/util/msisdn.py51
-rw-r--r--synapse/util/patch_inline_callbacks.py6
-rw-r--r--synapse/util/ratelimitutils.py2
-rw-r--r--synapse/util/rust.py87
-rw-r--r--synapse/util/stringutils.py12
-rw-r--r--synapse/util/task_scheduler.py155
-rw-r--r--synapse/util/threepids.py123
-rw-r--r--synapse/util/wheel_timer.py6
-rw-r--r--synapse/visibility.py72
327 files changed, 21452 insertions, 14059 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py

index 99ed7a5374..e7784ac5d7 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py
@@ -20,8 +20,7 @@ # # -""" This is an implementation of a Matrix homeserver. -""" +"""This is an implementation of a Matrix homeserver.""" import os import sys @@ -40,8 +39,8 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True # Note that we use an (unneeded) variable here so that pyupgrade doesn't nuke the # if-statement completely. py_version = sys.version_info -if py_version < (3, 8): - print("Synapse requires Python 3.8 or above.") +if py_version < (3, 9): + print("Synapse requires Python 3.9 or above.") sys.exit(1) # Allow using the asyncio reactor via env var. diff --git a/synapse/_pydantic_compat.py b/synapse/_pydantic_compat.py
index a6ceeb04d2..f0eedf5c6d 100644 --- a/synapse/_pydantic_compat.py +++ b/synapse/_pydantic_compat.py
@@ -19,6 +19,8 @@ # # +from typing import TYPE_CHECKING + from packaging.version import Version try: @@ -30,4 +32,64 @@ except ImportError: HAS_PYDANTIC_V2: bool = Version(pydantic_version).major == 2 -__all__ = ("HAS_PYDANTIC_V2",) +if TYPE_CHECKING or HAS_PYDANTIC_V2: + from pydantic.v1 import ( + BaseModel, + Extra, + Field, + MissingError, + PydanticValueError, + StrictBool, + StrictInt, + StrictStr, + ValidationError, + conbytes, + confloat, + conint, + constr, + parse_obj_as, + validator, + ) + from pydantic.v1.error_wrappers import ErrorWrapper + from pydantic.v1.typing import get_args +else: + from pydantic import ( + BaseModel, + Extra, + Field, + MissingError, + PydanticValueError, + StrictBool, + StrictInt, + StrictStr, + ValidationError, + conbytes, + confloat, + conint, + constr, + parse_obj_as, + validator, + ) + from pydantic.error_wrappers import ErrorWrapper + from pydantic.typing import get_args + +__all__ = ( + "HAS_PYDANTIC_V2", + "BaseModel", + "constr", + "conbytes", + "conint", + "confloat", + "ErrorWrapper", + "Extra", + "Field", + "get_args", + "MissingError", + "parse_obj_as", + "PydanticValueError", + "StrictBool", + "StrictInt", + "StrictStr", + "ValidationError", + "validator", +) diff --git a/synapse/_scripts/generate_workers_map.py b/synapse/_scripts/generate_workers_map.py
index 715c7ddc17..09feb8cf30 100755 --- a/synapse/_scripts/generate_workers_map.py +++ b/synapse/_scripts/generate_workers_map.py
@@ -171,7 +171,7 @@ def elide_http_methods_if_unconflicting( """ def paths_to_methods_dict( - methods_and_paths: Iterable[Tuple[str, str]] + methods_and_paths: Iterable[Tuple[str, str]], ) -> Dict[str, Set[str]]: """ Given (method, path) pairs, produces a dict from path to set of methods @@ -201,7 +201,7 @@ def elide_http_methods_if_unconflicting( def simplify_path_regexes( - registrations: Dict[Tuple[str, str], EndpointDescription] + registrations: Dict[Tuple[str, str], EndpointDescription], ) -> Dict[Tuple[str, str], EndpointDescription]: """ Simplify all the path regexes for the dict of endpoint descriptions, diff --git a/synapse/_scripts/hash_password.py b/synapse/_scripts/hash_password.py
index 3bed367be2..2b7d3585cb 100755 --- a/synapse/_scripts/hash_password.py +++ b/synapse/_scripts/hash_password.py
@@ -56,7 +56,9 @@ def main() -> None: password_pepper = password_config.get("pepper", password_pepper) password = args.password - if not password: + if not password and not sys.stdin.isatty(): + password = sys.stdin.readline().strip() + elif not password: password = prompt_for_pass() # On Python 2, make sure we decode it to Unicode before we normalise it diff --git a/synapse/_scripts/review_recent_signups.py b/synapse/_scripts/review_recent_signups.py
index ad88df477a..03bd58a1a1 100644 --- a/synapse/_scripts/review_recent_signups.py +++ b/synapse/_scripts/review_recent_signups.py
@@ -40,6 +40,7 @@ from synapse.storage.engines import create_engine class ReviewConfig(RootConfig): "A config class that just pulls out the database config" + config_classes = [DatabaseConfig] @@ -73,13 +74,6 @@ def get_recent_users( user_infos = [UserInfo(user_id, creation_ts) for user_id, creation_ts in txn] for user_info in user_infos: - user_info.emails = DatabasePool.simple_select_onecol_txn( - txn, - table="user_threepids", - keyvalues={"user_id": user_info.user_id, "medium": "email"}, - retcol="address", - ) - sql = """ SELECT room_id, canonical_alias, name, join_rules FROM local_current_membership @@ -160,7 +154,11 @@ def main() -> None: with make_conn(database_config, engine, "review_recent_signups") as db_conn: # This generates a type of Cursor, not LoggingTransaction. - user_infos = get_recent_users(db_conn.cursor(), since_ms, exclude_users_with_appservice) # type: ignore[arg-type] + user_infos = get_recent_users( + db_conn.cursor(), + since_ms, # type: ignore[arg-type] + exclude_users_with_appservice, + ) for user_info in user_infos: if exclude_users_with_email and user_info.emails: diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 5c6db8118f..573c70696e 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py
@@ -42,12 +42,12 @@ from typing import ( Set, Tuple, Type, + TypedDict, TypeVar, cast, ) import yaml -from typing_extensions import TypedDict from twisted.internet import defer, reactor as reactor_ @@ -88,6 +88,7 @@ from synapse.storage.databases.main.relations import RelationsWorkerStore from synapse.storage.databases.main.room import RoomBackgroundUpdateStore from synapse.storage.databases.main.roommember import RoomMemberBackgroundUpdateStore from synapse.storage.databases.main.search import SearchBackgroundUpdateStore +from synapse.storage.databases.main.sliding_sync import SlidingSyncStore from synapse.storage.databases.main.state import MainStateBackgroundUpdateStore from synapse.storage.databases.main.stats import StatsStore from synapse.storage.databases.main.user_directory import ( @@ -127,8 +128,14 @@ BOOLEAN_COLUMNS = { "pushers": ["enabled"], "redactions": ["have_censored"], "remote_media_cache": ["authenticated"], + "room_memberships": ["participant"], "room_stats_state": ["is_federatable"], "rooms": ["is_public", "has_auth_chain_index"], + "sliding_sync_joined_rooms": ["is_encrypted"], + "sliding_sync_membership_snapshots": [ + "has_known_state", + "is_encrypted", + ], "users": ["shadow_banned", "approved", "locked", "suspended"], "un_partial_stated_event_stream": ["rejection_status_changed"], "users_who_share_rooms": ["share_private"], @@ -185,6 +192,11 @@ APPEND_ONLY_TABLES = [ IGNORED_TABLES = { + # Porting the auto generated sequence in this table is non-trivial. + # None of the entries in this list are mandatory for Synapse to keep working. + # If state group disk space is an issue after the port, the + # `mark_unreferenced_state_groups_for_deletion_bg_update` background task can be run again. + "state_groups_pending_deletion", # We don't port these tables, as they're a faff and we can regenerate # them anyway. 
"user_directory", @@ -210,6 +222,15 @@ IGNORED_TABLES = { } +# These background updates will not be applied upon creation of the postgres database. +IGNORED_BACKGROUND_UPDATES = { + # Reapplying this background update to the postgres database is unnecessary after + # already having waited for the SQLite database to complete all running background + # updates. + "mark_unreferenced_state_groups_for_deletion_bg_update", +} + + # Error returned by the run function. Used at the top-level part of the script to # handle errors and return codes. end_error: Optional[str] = None @@ -250,6 +271,7 @@ class Store( ReceiptsBackgroundUpdateStore, RelationsWorkerStore, EventFederationWorkerStore, + SlidingSyncStore, ): def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]: return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs) @@ -680,6 +702,20 @@ class Porter: # 0 means off. 1 means full. 2 means incremental. return autovacuum_setting != 0 + async def remove_ignored_background_updates_from_database(self) -> None: + def _remove_delete_unreferenced_state_groups_bg_updates( + txn: LoggingTransaction, + ) -> None: + txn.execute( + "DELETE FROM background_updates WHERE update_name = ANY(?)", + (list(IGNORED_BACKGROUND_UPDATES),), + ) + + await self.postgres_store.db_pool.runInteraction( + "remove_delete_unreferenced_state_groups_bg_updates", + _remove_delete_unreferenced_state_groups_bg_updates, + ) + async def run(self) -> None: """Ports the SQLite database to a PostgreSQL database. @@ -712,9 +748,7 @@ class Porter: return # Check if all background updates are done, abort if not. - updates_complete = ( - await self.sqlite_store.db_pool.updates.has_completed_background_updates() - ) + updates_complete = await self.sqlite_store.db_pool.updates.has_completed_background_updates() if not updates_complete: end_error = ( "Pending background updates exist in the SQLite3 database." 
@@ -727,6 +761,8 @@ class Porter: self.hs_config.database.get_single_database() ) + await self.remove_ignored_background_updates_from_database() + await self.run_background_updates_on_postgres() self.progress.set_state("Creating port tables") @@ -1029,7 +1065,7 @@ class Porter: def get_sent_table_size(txn: LoggingTransaction) -> int: txn.execute( - "SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,) + "SELECT count(*) FROM sent_transactions WHERE ts >= ?", (yesterday,) ) result = txn.fetchone() assert result is not None @@ -1090,10 +1126,10 @@ class Porter: return done, remaining + done async def _setup_state_group_id_seq(self) -> None: - curr_id: Optional[int] = ( - await self.sqlite_store.db_pool.simple_select_one_onecol( - table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True - ) + curr_id: Optional[ + int + ] = await self.sqlite_store.db_pool.simple_select_one_onecol( + table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True ) if not curr_id: @@ -1181,13 +1217,13 @@ class Porter: ) async def _setup_auth_chain_sequence(self) -> None: - curr_chain_id: Optional[int] = ( - await self.sqlite_store.db_pool.simple_select_one_onecol( - table="event_auth_chains", - keyvalues={}, - retcol="MAX(chain_id)", - allow_none=True, - ) + curr_chain_id: Optional[ + int + ] = await self.sqlite_store.db_pool.simple_select_one_onecol( + table="event_auth_chains", + keyvalues={}, + retcol="MAX(chain_id)", + allow_none=True, ) def r(txn: LoggingTransaction) -> None: diff --git a/synapse/_scripts/synctl.py b/synapse/_scripts/synctl.py
index 688df9485c..2e2aa27a17 100755 --- a/synapse/_scripts/synctl.py +++ b/synapse/_scripts/synctl.py
@@ -292,9 +292,9 @@ def main() -> None: for key in worker_config: if key == "worker_app": # But we allow worker_app continue - assert not key.startswith( - "worker_" - ), "Main process cannot use worker_* config" + assert not key.startswith("worker_"), ( + "Main process cannot use worker_* config" + ) else: worker_pidfile = worker_config["worker_pid_file"] worker_cache_factor = worker_config.get("synctl_cache_factor") diff --git a/synapse/api/auth/__init__.py b/synapse/api/auth/__init__.py
index d5241afe73..1b801d3ad3 100644 --- a/synapse/api/auth/__init__.py +++ b/synapse/api/auth/__init__.py
@@ -18,9 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import TYPE_CHECKING, Optional, Tuple - -from typing_extensions import Protocol +from typing import TYPE_CHECKING, Optional, Protocol, Tuple from twisted.web.server import Request diff --git a/synapse/api/auth/msc3861_delegated.py b/synapse/api/auth/msc3861_delegated.py
index 7361666c77..e500a06afe 100644 --- a/synapse/api/auth/msc3861_delegated.py +++ b/synapse/api/auth/msc3861_delegated.py
@@ -19,7 +19,8 @@ # # import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional from urllib.parse import urlencode from authlib.oauth2 import ClientAuth @@ -38,15 +39,16 @@ from synapse.api.errors import ( HttpResponseException, InvalidClientTokenError, OAuthInsufficientScopeError, - StoreError, SynapseError, UnrecognizedRequestError, ) from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable +from synapse.logging.opentracing import active_span, force_tracing, start_active_span from synapse.types import Requester, UserID, create_requester from synapse.util import json_decoder from synapse.util.caches.cached_call import RetryOnExceptionCachedCall +from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext if TYPE_CHECKING: from synapse.rest.admin.experimental_features import ExperimentalFeature @@ -76,6 +78,61 @@ def scope_to_list(scope: str) -> List[str]: return scope.strip().split(" ") +@dataclass +class IntrospectionResult: + _inner: IntrospectionToken + + # when we retrieved this token, + # in milliseconds since the Unix epoch + retrieved_at_ms: int + + def is_active(self, now_ms: int) -> bool: + if not self._inner.get("active"): + return False + + expires_in = self._inner.get("expires_in") + if expires_in is None: + return True + if not isinstance(expires_in, int): + raise InvalidClientTokenError("token `expires_in` is not an int") + + absolute_expiry_ms = expires_in * 1000 + self.retrieved_at_ms + return now_ms < absolute_expiry_ms + + def get_scope_list(self) -> List[str]: + value = self._inner.get("scope") + if not isinstance(value, str): + return [] + return scope_to_list(value) + + def get_sub(self) -> Optional[str]: + value = self._inner.get("sub") + if not isinstance(value, str): + return None + return value + + def get_username(self) -> Optional[str]: + value = 
self._inner.get("username") + if not isinstance(value, str): + return None + return value + + def get_name(self) -> Optional[str]: + value = self._inner.get("name") + if not isinstance(value, str): + return None + return value + + def get_device_id(self) -> Optional[str]: + value = self._inner.get("device_id") + if value is not None and not isinstance(value, str): + raise AuthError( + 500, + "Invalid device ID in introspection result", + ) + return value + + class PrivateKeyJWTWithKid(PrivateKeyJWT): # type: ignore[misc] """An implementation of the private_key_jwt client auth method that includes a kid header. @@ -119,9 +176,39 @@ class MSC3861DelegatedAuth(BaseAuth): self._clock = hs.get_clock() self._http_client = hs.get_proxied_http_client() self._hostname = hs.hostname - self._admin_token = self._config.admin_token + self._admin_token: Callable[[], Optional[str]] = self._config.admin_token + self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users + + # # Token Introspection Cache + # This remembers what users/devices are represented by which access tokens, + # in order to reduce overall system load: + # - on Synapse (as requests are relatively expensive) + # - on the network + # - on MAS + # + # Since there is no invalidation mechanism currently, + # the entries expire after 2 minutes. + # This does mean tokens can be treated as valid by Synapse + # for longer than reality. + # + # Ideally, tokens should logically be invalidated in the following circumstances: + # - If a session logout happens. + # In this case, MAS will delete the device within Synapse + # anyway and this is good enough as an invalidation. + # - If the client refreshes their token in MAS. + # In this case, the device still exists and it's not the end of the world for + # the old access token to continue working for a short time. 
+ self._introspection_cache: ResponseCache[str] = ResponseCache( + self._clock, + "token_introspection", + timeout_ms=120_000, + # don't log because the keys are access tokens + enable_logging=False, + ) - self._issuer_metadata = RetryOnExceptionCachedCall(self._load_metadata) + self._issuer_metadata = RetryOnExceptionCachedCall[OpenIDProviderMetadata]( + self._load_metadata + ) if isinstance(auth_method, PrivateKeyJWTWithKid): # Use the JWK as the client secret when using the private_key_jwt method @@ -131,9 +218,10 @@ class MSC3861DelegatedAuth(BaseAuth): ) else: # Else use the client secret - assert self._config.client_secret, "No client_secret provided" + client_secret = self._config.client_secret() + assert client_secret, "No client_secret provided" self._client_auth = ClientAuth( - self._config.client_id, self._config.client_secret, auth_method + self._config.client_id, client_secret, auth_method ) async def _load_metadata(self) -> OpenIDProviderMetadata: @@ -145,6 +233,39 @@ class MSC3861DelegatedAuth(BaseAuth): # metadata.validate_introspection_endpoint() return metadata + async def issuer(self) -> str: + """ + Get the configured issuer + + This will use the issuer value set in the metadata, + falling back to the one set in the config if not set in the metadata + """ + metadata = await self._issuer_metadata.get() + return metadata.issuer or self._config.issuer + + async def account_management_url(self) -> Optional[str]: + """ + Get the configured account management URL + + This will discover the account management URL from the issuer if it's not set in the config + """ + if self._config.account_management_url is not None: + return self._config.account_management_url + + try: + metadata = await self._issuer_metadata.get() + return metadata.get("account_management_uri", None) + # We don't want to raise here if we can't load the metadata + except Exception: + logger.warning("Failed to load metadata:", exc_info=True) + return None + + async def 
auth_metadata(self) -> Dict[str, Any]: + """ + Returns the auth metadata dict + """ + return await self._issuer_metadata.get() + async def _introspection_endpoint(self) -> str: """ Returns the introspection endpoint of the issuer @@ -154,10 +275,12 @@ class MSC3861DelegatedAuth(BaseAuth): if self._config.introspection_endpoint is not None: return self._config.introspection_endpoint - metadata = await self._load_metadata() + metadata = await self._issuer_metadata.get() return metadata.get("introspection_endpoint") - async def _introspect_token(self, token: str) -> IntrospectionToken: + async def _introspect_token( + self, token: str, cache_context: ResponseCacheContext[str] + ) -> IntrospectionResult: """ Send a token to the introspection endpoint and returns the introspection response @@ -173,11 +296,16 @@ class MSC3861DelegatedAuth(BaseAuth): Returns: The introspection response """ + # By default, we shouldn't cache the result unless we know it's valid + cache_context.should_cache = False introspection_endpoint = await self._introspection_endpoint() raw_headers: Dict[str, str] = { "Content-Type": "application/x-www-form-urlencoded", "User-Agent": str(self._http_client.user_agent, "utf-8"), "Accept": "application/json", + # Tell MAS that we support reading the device ID as an explicit + # value, not encoded in the scope. This is supported by MAS 0.15+ + "X-MAS-Supports-Device-Id": "1", } args = {"token": token, "token_type_hint": "access_token"} @@ -227,7 +355,11 @@ class MSC3861DelegatedAuth(BaseAuth): "The introspection endpoint returned an invalid JSON response." 
) - return IntrospectionToken(**resp) + # We had a valid response, so we can cache it + cache_context.should_cache = True + return IntrospectionResult( + IntrospectionToken(**resp), retrieved_at_ms=self._clock.time_msec() + ) async def is_server_admin(self, requester: Requester) -> bool: return "urn:synapse:admin:*" in requester.scope @@ -239,6 +371,55 @@ class MSC3861DelegatedAuth(BaseAuth): allow_expired: bool = False, allow_locked: bool = False, ) -> Requester: + """Get a registered user's ID. + + Args: + request: An HTTP request with an access_token query parameter. + allow_guest: If False, will raise an AuthError if the user making the + request is a guest. + allow_expired: If True, allow the request through even if the account + is expired, or session token lifetime has ended. Note that + /login will deliver access tokens regardless of expiration. + + Returns: + Resolves to the requester + Raises: + InvalidClientCredentialsError if no user by that token exists or the token + is invalid. + AuthError if access is denied for the user in the access token + """ + parent_span = active_span() + with start_active_span("get_user_by_req"): + requester = await self._wrapped_get_user_by_req( + request, allow_guest, allow_expired, allow_locked + ) + + if parent_span: + if requester.authenticated_entity in self._force_tracing_for_users: + # request tracing is enabled for this user, so we need to force it + # tracing on for the parent span (which will be the servlet span). + # + # It's too late for the get_user_by_req span to inherit the setting, + # so we also force it on for that. 
+ force_tracing() + force_tracing(parent_span) + parent_span.set_tag( + "authenticated_entity", requester.authenticated_entity + ) + parent_span.set_tag("user_id", requester.user.to_string()) + if requester.device_id is not None: + parent_span.set_tag("device_id", requester.device_id) + if requester.app_service is not None: + parent_span.set_tag("appservice_id", requester.app_service.id) + return requester + + async def _wrapped_get_user_by_req( + self, + request: SynapseRequest, + allow_guest: bool = False, + allow_expired: bool = False, + allow_locked: bool = False, + ) -> Requester: access_token = self.get_access_token_from_request(request) requester = await self.get_appservice_user(request, access_token) @@ -248,7 +429,7 @@ class MSC3861DelegatedAuth(BaseAuth): requester = await self.get_user_by_access_token(access_token, allow_expired) # Do not record requests from MAS using the virtual `__oidc_admin` user. - if access_token != self._admin_token: + if access_token != self._admin_token(): await self._record_request(request, requester) if not allow_guest and requester.is_guest: @@ -289,7 +470,8 @@ class MSC3861DelegatedAuth(BaseAuth): token: str, allow_expired: bool = False, ) -> Requester: - if self._admin_token is not None and token == self._admin_token: + admin_token = self._admin_token() + if admin_token is not None and token == admin_token: # XXX: This is a temporary solution so that the admin API can be called by # the OIDC provider. This will be removed once we have OIDC client # credentials grant support in matrix-authentication-service. 
@@ -304,20 +486,22 @@ class MSC3861DelegatedAuth(BaseAuth): ) try: - introspection_result = await self._introspect_token(token) + introspection_result = await self._introspection_cache.wrap( + token, self._introspect_token, token, cache_context=True + ) except Exception: logger.exception("Failed to introspect token") raise SynapseError(503, "Unable to introspect the access token") - logger.info(f"Introspection result: {introspection_result!r}") + logger.debug("Introspection result: %r", introspection_result) # TODO: introspection verification should be more extensive, especially: # - verify the audience - if not introspection_result.get("active"): + if not introspection_result.is_active(self._clock.time_msec()): raise InvalidClientTokenError("Token is not active") # Let's look at the scope - scope: List[str] = scope_to_list(introspection_result.get("scope", "")) + scope: List[str] = introspection_result.get_scope_list() # Determine type of user based on presence of particular scopes has_user_scope = SCOPE_MATRIX_API in scope @@ -327,7 +511,7 @@ class MSC3861DelegatedAuth(BaseAuth): raise InvalidClientTokenError("No scope in token granting user rights") # Match via the sub claim - sub: Optional[str] = introspection_result.get("sub") + sub = introspection_result.get_sub() if sub is None: raise InvalidClientTokenError( "Invalid sub claim in the introspection result" @@ -340,29 +524,20 @@ class MSC3861DelegatedAuth(BaseAuth): # If we could not find a user via the external_id, it either does not exist, # or the external_id was never recorded - # TODO: claim mapping should be configurable - username: Optional[str] = introspection_result.get("username") - if username is None or not isinstance(username, str): + username = introspection_result.get_username() + if username is None: raise AuthError( 500, "Invalid username claim in the introspection result", ) user_id = UserID(username, self._hostname) - # First try to find a user from the username claim + # Try to find a user 
from the username claim user_info = await self.store.get_user_by_id(user_id=user_id.to_string()) if user_info is None: - # If the user does not exist, we should create it on the fly - # TODO: we could use SCIM to provision users ahead of time and listen - # for SCIM SET events if those ever become standard: - # https://datatracker.ietf.org/doc/html/draft-hunt-scim-notify-00 - - # TODO: claim mapping should be configurable - # If present, use the name claim as the displayname - name: Optional[str] = introspection_result.get("name") - - await self.store.register_user( - user_id=user_id.to_string(), create_profile_with_displayname=name + raise AuthError( + 500, + "User not found", ) # And record the sub as external_id @@ -372,42 +547,40 @@ class MSC3861DelegatedAuth(BaseAuth): else: user_id = UserID.from_string(user_id_str) - # Find device_ids in scope - # We only allow a single device_id in the scope, so we find them all in the - # scope list, and raise if there are more than one. The OIDC server should be - # the one enforcing valid scopes, so we raise a 500 if we find an invalid scope. - device_ids = [ - tok[len(SCOPE_MATRIX_DEVICE_PREFIX) :] - for tok in scope - if tok.startswith(SCOPE_MATRIX_DEVICE_PREFIX) - ] - - if len(device_ids) > 1: - raise AuthError( - 500, - "Multiple device IDs in scope", - ) + # MAS 0.15+ will give us the device ID as an explicit value for compatibility sessions + # If present, we get it from here, if not we get it in thee scope + device_id = introspection_result.get_device_id() + if device_id is None: + # Find device_ids in scope + # We only allow a single device_id in the scope, so we find them all in the + # scope list, and raise if there are more than one. The OIDC server should be + # the one enforcing valid scopes, so we raise a 500 if we find an invalid scope. 
+ device_ids = [ + tok[len(SCOPE_MATRIX_DEVICE_PREFIX) :] + for tok in scope + if tok.startswith(SCOPE_MATRIX_DEVICE_PREFIX) + ] + + if len(device_ids) > 1: + raise AuthError( + 500, + "Multiple device IDs in scope", + ) + + device_id = device_ids[0] if device_ids else None - device_id = device_ids[0] if device_ids else None if device_id is not None: # Sanity check the device_id if len(device_id) > 255 or len(device_id) < 1: raise AuthError( 500, - "Invalid device ID in scope", + "Invalid device ID in introspection result", ) - # Create the device on the fly if it does not exist - try: - await self.store.get_device( - user_id=user_id.to_string(), device_id=device_id - ) - except StoreError: - await self.store.store_device( - user_id=user_id.to_string(), - device_id=device_id, - initial_device_display_name="OIDC-native client", - ) + # Make sure the device exists + await self.store.get_device( + user_id=user_id.to_string(), device_id=device_id + ) # TODO: there is a few things missing in the requester here, which still need # to be figured out, like: diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py
index 303c9ba03e..a56ffd58e4 100644 --- a/synapse/api/auth_blocking.py +++ b/synapse/api/auth_blocking.py
@@ -24,7 +24,6 @@ from typing import TYPE_CHECKING, Optional from synapse.api.constants import LimitBlockingTypes, UserTypes from synapse.api.errors import Codes, ResourceLimitError -from synapse.config.server import is_threepid_reserved from synapse.types import Requester if TYPE_CHECKING: @@ -43,16 +42,13 @@ class AuthBlocking: self._admin_contact = hs.config.server.admin_contact self._max_mau_value = hs.config.server.max_mau_value self._limit_usage_by_mau = hs.config.server.limit_usage_by_mau - self._mau_limits_reserved_threepids = ( - hs.config.server.mau_limits_reserved_threepids - ) self._is_mine_server_name = hs.is_mine_server_name self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips async def check_auth_blocking( self, user_id: Optional[str] = None, - threepid: Optional[dict] = None, + threepid: Optional[str] = None, # Not used in this method, but kept for compatibility user_type: Optional[str] = None, requester: Optional[Requester] = None, ) -> None: @@ -63,12 +59,6 @@ class AuthBlocking: user_id: If present, checks for presence against existing MAU cohort - threepid: If present, checks for presence against configured - reserved threepid. Used in cases where the user is trying register - with a MAU blocked server, normally they would be rejected but their - threepid is on the reserved list. user_id and - threepid should never be set at the same time. - user_type: If present, is used to decide whether to check against certain blocking reasons like MAU. 
@@ -111,9 +101,8 @@ class AuthBlocking: admin_contact=self._admin_contact, limit_type=LimitBlockingTypes.HS_DISABLED, ) - if self._limit_usage_by_mau is True: - assert not (user_id and threepid) + if self._limit_usage_by_mau is True: # If the user is already part of the MAU cohort or a trial user if user_id: timestamp = await self.store.user_last_seen_monthly_active(user_id) @@ -123,11 +112,6 @@ class AuthBlocking: is_trial = await self.store.is_trial_user(user_id) if is_trial: return - elif threepid: - # If the user does not exist yet, but is signing up with a - # reserved threepid then pass auth check - if is_threepid_reserved(self._mau_limits_reserved_threepids, threepid): - return elif user_type == UserTypes.SUPPORT: # If the user does not exist yet and is of type "support", # allow registration. Support users are excluded from MAU checks. diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 7dcb1e01fd..cd2ebf2cc3 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py
@@ -29,8 +29,13 @@ from typing import Final # the max size of a (canonical-json-encoded) event MAX_PDU_SIZE = 65536 -# the "depth" field on events is limited to 2**63 - 1 -MAX_DEPTH = 2**63 - 1 +# Max/min size of ints in canonical JSON +CANONICALJSON_MAX_INT = (2**53) - 1 +CANONICALJSON_MIN_INT = -CANONICALJSON_MAX_INT + +# the "depth" field on events is limited to the same as what +# canonicaljson accepts +MAX_DEPTH = CANONICALJSON_MAX_INT # the maximum length for a room alias is 255 characters MAX_ALIAS_LENGTH = 255 @@ -81,8 +86,6 @@ class RestrictedJoinRuleTypes: class LoginType: PASSWORD: Final = "m.login.password" - EMAIL_IDENTITY: Final = "m.login.email.identity" - MSISDN: Final = "m.login.msisdn" RECAPTCHA: Final = "m.login.recaptcha" TERMS: Final = "m.login.terms" SSO: Final = "m.login.sso" @@ -180,12 +183,18 @@ ServerNoticeLimitReached: Final = "m.server_notice.usage_limit_reached" class UserTypes: """Allows for user type specific behaviour. With the benefit of hindsight - 'admin' and 'guest' users should also be UserTypes. Normal users are type None + 'admin' and 'guest' users should also be UserTypes. Extra user types can be + added in the configuration. Normal users are type None or one of the extra + user types (if configured). """ SUPPORT: Final = "support" BOT: Final = "bot" - ALL_USER_TYPES: Final = (SUPPORT, BOT) + ALL_BUILTIN_USER_TYPES: Final = (SUPPORT, BOT) + """ + The user types that are built-in to Synapse. Extra user types can be + added in the configuration. + """ class RelationTypes: @@ -230,6 +239,10 @@ class EventContentFields: ROOM_NAME: Final = "name" + MEMBERSHIP: Final = "membership" + MEMBERSHIP_DISPLAYNAME: Final = "displayname" + MEMBERSHIP_AVATAR_URL: Final = "avatar_url" + # Used in m.room.guest_access events. 
GUEST_ACCESS: Final = "guest_access" @@ -245,6 +258,8 @@ class EventContentFields: # `m.room.encryption`` algorithm field ENCRYPTION_ALGORITHM: Final = "algorithm" + TOMBSTONE_SUCCESSOR_ROOM: Final = "replacement_room" + class EventUnsignedContentFields: """Fields found inside the 'unsigned' data on events""" @@ -269,6 +284,10 @@ class AccountDataTypes: IGNORED_USER_LIST: Final = "m.ignored_user_list" TAG: Final = "m.tag" PUSH_RULES: Final = "m.push_rules" + # MSC4155: Invite filtering + MSC4155_INVITE_PERMISSION_CONFIG: Final = ( + "org.matrix.msc4155.invite_permission_config" + ) class HistoryVisibility: @@ -314,3 +333,8 @@ class ApprovalNoticeMedium: class Direction(enum.Enum): BACKWARDS = "b" FORWARDS = "f" + + +class ProfileFields: + DISPLAYNAME: Final = "displayname" + AVATAR_URL: Final = "avatar_url" diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index e6efa7a424..a095fb195b 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py
@@ -65,11 +65,8 @@ class Codes(str, Enum): INVALID_PARAM = "M_INVALID_PARAM" TOO_LARGE = "M_TOO_LARGE" EXCLUSIVE = "M_EXCLUSIVE" - THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED" - THREEPID_IN_USE = "M_THREEPID_IN_USE" - THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND" - THREEPID_DENIED = "M_THREEPID_DENIED" INVALID_USERNAME = "M_INVALID_USERNAME" + THREEPID_MEDIUM_NOT_SUPPORTED = "M_THREEPID_MEDIUM_NOT_SUPPORTED" # Kept around for throwing when 3PID is attempted SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED" CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN" CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM" @@ -87,8 +84,7 @@ class Codes(str, Enum): WEAK_PASSWORD = "M_WEAK_PASSWORD" INVALID_SIGNATURE = "M_INVALID_SIGNATURE" USER_DEACTIVATED = "M_USER_DEACTIVATED" - # USER_LOCKED = "M_USER_LOCKED" - USER_LOCKED = "ORG_MATRIX_MSC3939_USER_LOCKED" + USER_LOCKED = "M_USER_LOCKED" NOT_YET_UPLOADED = "M_NOT_YET_UPLOADED" CANNOT_OVERWRITE_MEDIA = "M_CANNOT_OVERWRITE_MEDIA" @@ -101,8 +97,9 @@ class Codes(str, Enum): # The account has been suspended on the server. # By opposition to `USER_DEACTIVATED`, this is a reversible measure # that can possibly be appealed and reverted. - # Part of MSC3823. - USER_ACCOUNT_SUSPENDED = "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" + # Introduced by MSC3823 + # https://github.com/matrix-org/matrix-spec-proposals/pull/3823 + USER_ACCOUNT_SUSPENDED = "M_USER_SUSPENDED" BAD_ALIAS = "M_BAD_ALIAS" # For restricted join rules. @@ -132,6 +129,13 @@ class Codes(str, Enum): # connection. UNKNOWN_POS = "M_UNKNOWN_POS" + # Part of MSC4133 + PROFILE_TOO_LARGE = "M_PROFILE_TOO_LARGE" + KEY_TOO_LARGE = "M_KEY_TOO_LARGE" + + # Part of MSC4155 + INVITE_BLOCKED = "ORG.MATRIX.MSC4155.M_INVITE_BLOCKED" + class CodeMessageException(RuntimeError): """An exception with integer code, a message string attributes and optional headers. 
@@ -573,13 +577,6 @@ class UnsupportedRoomVersionError(SynapseError): ) -class ThreepidValidationError(SynapseError): - """An error raised when there was a problem authorising an event.""" - - def __init__(self, msg: str, errcode: str = Codes.FORBIDDEN): - super().__init__(400, msg, errcode) - - class IncompatibleRoomVersionError(SynapseError): """A server is trying to join a room whose version it does not support. diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index b80630c5d3..4f3bf8f770 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py
@@ -20,8 +20,7 @@ # # -from collections import OrderedDict -from typing import Hashable, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Hashable, Optional, Tuple from synapse.api.errors import LimitExceededError from synapse.config.ratelimiting import RatelimitSettings @@ -29,6 +28,12 @@ from synapse.storage.databases.main import DataStore from synapse.types import Requester from synapse.util import Clock +if TYPE_CHECKING: + # To avoid circular imports: + from synapse.module_api.callbacks.ratelimit_callbacks import ( + RatelimitModuleApiCallbacks, + ) + class Ratelimiter: """ @@ -73,19 +78,23 @@ class Ratelimiter: store: DataStore, clock: Clock, cfg: RatelimitSettings, + ratelimit_callbacks: Optional["RatelimitModuleApiCallbacks"] = None, ): self.clock = clock self.rate_hz = cfg.per_second self.burst_count = cfg.burst_count self.store = store self._limiter_name = cfg.key + self._ratelimit_callbacks = ratelimit_callbacks - # An ordered dictionary representing the token buckets tracked by this rate + # A dictionary representing the token buckets tracked by this rate # limiter. Each entry maps a key of arbitrary type to a tuple representing: # * The number of tokens currently in the bucket, # * The time point when the bucket was last completely empty, and # * The rate_hz (leak rate) of this particular bucket. - self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict() + self.actions: Dict[Hashable, Tuple[float, float, float]] = {} + + self.clock.looping_call(self._prune_message_counts, 60 * 1000) def _get_key( self, requester: Optional[Requester], key: Optional[Hashable] @@ -164,14 +173,25 @@ class Ratelimiter: if override and not override.messages_per_second: return True, -1.0 + if requester and self._ratelimit_callbacks: + # Check if the user has a custom rate limit for this specific limiter + # as returned by the module API. 
+ module_override = ( + await self._ratelimit_callbacks.get_ratelimit_override_for_user( + requester.user.to_string(), + self._limiter_name, + ) + ) + + if module_override: + rate_hz = module_override.per_second + burst_count = module_override.burst_count + # Override default values if set time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() rate_hz = rate_hz if rate_hz is not None else self.rate_hz burst_count = burst_count if burst_count is not None else self.burst_count - # Remove any expired entries - self._prune_message_counts(time_now_s) - # Check if there is an existing count entry for this key action_count, time_start, _ = self._get_action_counts(key, time_now_s) @@ -246,13 +266,12 @@ class Ratelimiter: action_count, time_start, rate_hz = self._get_action_counts(key, time_now_s) self.actions[key] = (action_count + n_actions, time_start, rate_hz) - def _prune_message_counts(self, time_now_s: float) -> None: + def _prune_message_counts(self) -> None: """Remove message count entries that have not exceeded their defined rate_hz limit - - Args: - time_now_s: The current time """ + time_now_s = self.clock.time() + # We create a copy of the key list here as the dictionary is modified during # the loop for key in list(self.actions.keys()): @@ -275,6 +294,7 @@ class Ratelimiter: update: bool = True, n_actions: int = 1, _time_now_s: Optional[float] = None, + pause: Optional[float] = 0.5, ) -> None: """Checks if an action can be performed. If not, raises a LimitExceededError @@ -298,6 +318,8 @@ class Ratelimiter: at all. _time_now_s: The current time. Optional, defaults to the current time according to self.clock. Only used by tests. + pause: Time in seconds to pause when an action is being limited. Defaults to 0.5 + to stop clients from "tight-looping" on retrying their request. 
Raises: LimitExceededError: If an action could not be performed, along with the time in @@ -316,9 +338,8 @@ class Ratelimiter: ) if not allowed: - # We pause for a bit here to stop clients from "tight-looping" on - # retrying their request. - await self.clock.sleep(0.5) + if pause: + await self.clock.sleep(pause) raise LimitExceededError( limiter_name=self._limiter_name, diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index fbc1d58ecb..697acc25ba 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py
@@ -107,6 +107,8 @@ class RoomVersion: # support the flag. Unknown flags are ignored by the evaluator, making conditions # fail if used. msc3931_push_features: Tuple[str, ...] # values from PushRuleRoomFlag + # MSC3757: Restricting who can overwrite a state event + msc3757_enabled: bool class RoomVersions: @@ -128,6 +130,7 @@ class RoomVersions: knock_restricted_join_rule=False, enforce_int_power_levels=False, msc3931_push_features=(), + msc3757_enabled=False, ) V2 = RoomVersion( "2", @@ -147,6 +150,7 @@ class RoomVersions: knock_restricted_join_rule=False, enforce_int_power_levels=False, msc3931_push_features=(), + msc3757_enabled=False, ) V3 = RoomVersion( "3", @@ -166,6 +170,7 @@ class RoomVersions: knock_restricted_join_rule=False, enforce_int_power_levels=False, msc3931_push_features=(), + msc3757_enabled=False, ) V4 = RoomVersion( "4", @@ -185,6 +190,7 @@ class RoomVersions: knock_restricted_join_rule=False, enforce_int_power_levels=False, msc3931_push_features=(), + msc3757_enabled=False, ) V5 = RoomVersion( "5", @@ -204,6 +210,7 @@ class RoomVersions: knock_restricted_join_rule=False, enforce_int_power_levels=False, msc3931_push_features=(), + msc3757_enabled=False, ) V6 = RoomVersion( "6", @@ -223,6 +230,7 @@ class RoomVersions: knock_restricted_join_rule=False, enforce_int_power_levels=False, msc3931_push_features=(), + msc3757_enabled=False, ) V7 = RoomVersion( "7", @@ -242,6 +250,7 @@ class RoomVersions: knock_restricted_join_rule=False, enforce_int_power_levels=False, msc3931_push_features=(), + msc3757_enabled=False, ) V8 = RoomVersion( "8", @@ -261,6 +270,7 @@ class RoomVersions: knock_restricted_join_rule=False, enforce_int_power_levels=False, msc3931_push_features=(), + msc3757_enabled=False, ) V9 = RoomVersion( "9", @@ -280,6 +290,7 @@ class RoomVersions: knock_restricted_join_rule=False, enforce_int_power_levels=False, msc3931_push_features=(), + msc3757_enabled=False, ) V10 = RoomVersion( "10", @@ -299,6 +310,7 @@ class RoomVersions: 
knock_restricted_join_rule=True, enforce_int_power_levels=True, msc3931_push_features=(), + msc3757_enabled=False, ) MSC1767v10 = RoomVersion( # MSC1767 (Extensible Events) based on room version "10" @@ -319,6 +331,28 @@ class RoomVersions: knock_restricted_join_rule=True, enforce_int_power_levels=True, msc3931_push_features=(PushRuleRoomFlag.EXTENSIBLE_EVENTS,), + msc3757_enabled=False, + ) + MSC3757v10 = RoomVersion( + # MSC3757 (Restricting who can overwrite a state event) based on room version "10" + "org.matrix.msc3757.10", + RoomDisposition.UNSTABLE, + EventFormatVersions.ROOM_V4_PLUS, + StateResolutionVersions.V2, + enforce_key_validity=True, + special_case_aliases_auth=False, + strict_canonicaljson=True, + limit_notifications_power_levels=True, + implicit_room_creator=False, + updated_redaction_rules=False, + restricted_join_rule=True, + restricted_join_rule_fix=True, + knock_join_rule=True, + msc3389_relation_redactions=False, + knock_restricted_join_rule=True, + enforce_int_power_levels=True, + msc3931_push_features=(), + msc3757_enabled=True, ) V11 = RoomVersion( "11", @@ -338,6 +372,28 @@ class RoomVersions: knock_restricted_join_rule=True, enforce_int_power_levels=True, msc3931_push_features=(), + msc3757_enabled=False, + ) + MSC3757v11 = RoomVersion( + # MSC3757 (Restricting who can overwrite a state event) based on room version "11" + "org.matrix.msc3757.11", + RoomDisposition.UNSTABLE, + EventFormatVersions.ROOM_V4_PLUS, + StateResolutionVersions.V2, + enforce_key_validity=True, + special_case_aliases_auth=False, + strict_canonicaljson=True, + limit_notifications_power_levels=True, + implicit_room_creator=True, # Used by MSC3820 + updated_redaction_rules=True, # Used by MSC3820 + restricted_join_rule=True, + restricted_join_rule_fix=True, + knock_join_rule=True, + msc3389_relation_redactions=False, + knock_restricted_join_rule=True, + enforce_int_power_levels=True, + msc3931_push_features=(), + msc3757_enabled=True, ) @@ -355,42 +411,7 @@ 
KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = { RoomVersions.V9, RoomVersions.V10, RoomVersions.V11, - ) -} - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class RoomVersionCapability: - """An object which describes the unique attributes of a room version.""" - - identifier: str # the identifier for this capability - preferred_version: Optional[RoomVersion] - support_check_lambda: Callable[[RoomVersion], bool] - - -MSC3244_CAPABILITIES = { - cap.identifier: { - "preferred": ( - cap.preferred_version.identifier - if cap.preferred_version is not None - else None - ), - "support": [ - v.identifier - for v in KNOWN_ROOM_VERSIONS.values() - if cap.support_check_lambda(v) - ], - } - for cap in ( - RoomVersionCapability( - "knock", - RoomVersions.V7, - lambda room_version: room_version.knock_join_rule, - ), - RoomVersionCapability( - "restricted", - RoomVersions.V9, - lambda room_version: room_version.restricted_join_rule, - ), + RoomVersions.MSC3757v10, + RoomVersions.MSC3757v11, ) } diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index d077a2c613..655b5edd7a 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py
@@ -19,10 +19,12 @@ # # -"""Contains the URL paths to prefix various aspects of the server with. """ +"""Contains the URL paths to prefix various aspects of the server with.""" + import hmac from hashlib import sha256 -from urllib.parse import urlencode +from typing import Optional +from urllib.parse import urlencode, urljoin from synapse.config import ConfigError from synapse.config.homeserver import HomeServerConfig @@ -65,3 +67,42 @@ class ConsentURIBuilder: urlencode({"u": user_id, "h": mac}), ) return consent_uri + + +class LoginSSORedirectURIBuilder: + def __init__(self, hs_config: HomeServerConfig): + self._public_baseurl = hs_config.server.public_baseurl + + def build_login_sso_redirect_uri( + self, *, idp_id: Optional[str], client_redirect_url: str + ) -> str: + """Build a `/login/sso/redirect` URI for the given identity provider. + + Builds `/_matrix/client/v3/login/sso/redirect/{idpId}?redirectUrl=xxx` when `idp_id` is specified. + Otherwise, builds `/_matrix/client/v3/login/sso/redirect?redirectUrl=xxx` when `idp_id` is `None`. + + Args: + idp_id: Optional ID of the identity provider + client_redirect_url: URL to redirect the user to after login + + Returns + The URI to follow when choosing a specific identity provider. + """ + base_url = urljoin( + self._public_baseurl, + f"{CLIENT_API_PREFIX}/v3/login/sso/redirect", + ) + + serialized_query_parameters = urlencode({"redirectUrl": client_redirect_url}) + + if idp_id: + resultant_url = urljoin( + # We have to add a trailing slash to the base URL to ensure that the + # last path segment is not stripped away when joining with another path. + f"{base_url}/", + f"{idp_id}?{serialized_query_parameters}", + ) + else: + resultant_url = f"{base_url}?{serialized_query_parameters}" + + return resultant_url diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 53f1859256..75c65ccc0d 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py
@@ -3,7 +3,7 @@ # # Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright 2016 OpenMarket Ltd -# Copyright (C) 2023 New Vector, Ltd +# Copyright (C) 2023-2024 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as @@ -51,8 +51,7 @@ from synapse.http.server import JsonResource, OptionsResource from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource -from synapse.rest import ClientRestResource -from synapse.rest.admin import register_servlets_for_media_repo +from synapse.rest import ClientRestResource, admin from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyResource from synapse.rest.synapse.client import build_synapse_client_resource_tree @@ -65,6 +64,7 @@ from synapse.storage.databases.main.appservice import ( ) from synapse.storage.databases.main.censor_events import CensorEventsStore from synapse.storage.databases.main.client_ips import ClientIpWorkerStore +from synapse.storage.databases.main.delayed_events import DelayedEventsStore from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore from synapse.storage.databases.main.devices import DeviceWorkerStore from synapse.storage.databases.main.directory import DirectoryWorkerStore @@ -98,6 +98,7 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.databases.main.search import SearchStore from synapse.storage.databases.main.session import SessionStore from synapse.storage.databases.main.signatures import SignatureWorkerStore +from synapse.storage.databases.main.sliding_sync import SlidingSyncStore from synapse.storage.databases.main.state import StateGroupWorkerStore from synapse.storage.databases.main.stats import StatsStore from 
synapse.storage.databases.main.stream import StreamWorkerStore @@ -159,6 +160,8 @@ class GenericWorkerStore( SessionStore, TaskSchedulerWorkerStore, ExperimentalFeaturesStore, + SlidingSyncStore, + DelayedEventsStore, ): # Properties that multiple storage classes define. Tell mypy what the # expected type is. @@ -172,8 +175,13 @@ class GenericWorkerServer(HomeServer): def _listen_http(self, listener_config: ListenerConfig) -> None: assert listener_config.http_options is not None - # We always include a health resource. - resources: Dict[str, Resource] = {"/health": HealthResource()} + # We always include an admin resource that we populate with servlets as needed + admin_resource = JsonResource(self, canonical_json=False) + resources: Dict[str, Resource] = { + # We always include a health resource. + "/health": HealthResource(), + "/_synapse/admin": admin_resource, + } for res in listener_config.http_options.resources: for name in res.names: @@ -186,6 +194,7 @@ class GenericWorkerServer(HomeServer): resources.update(build_synapse_client_resource_tree(self)) resources["/.well-known"] = well_known_resource(self) + admin.register_servlets(self, admin_resource) elif name == "federation": resources[FEDERATION_PREFIX] = TransportLayerServer(self) @@ -195,15 +204,13 @@ class GenericWorkerServer(HomeServer): # We need to serve the admin servlets for media on the # worker. - admin_resource = JsonResource(self, canonical_json=False) - register_servlets_for_media_repo(self, admin_resource) + admin.register_servlets_for_media_repo(self, admin_resource) resources.update( { MEDIA_R0_PREFIX: media_repo, MEDIA_V3_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo, - "/_synapse/admin": admin_resource, } ) @@ -280,8 +287,7 @@ class GenericWorkerServer(HomeServer): elif listener.type == "metrics": if not self.config.metrics.enable_metrics: logger.warning( - "Metrics listener configured, but " - "enable_metrics is not True!" 
+ "Metrics listener configured, but enable_metrics is not True!" ) else: if isinstance(listener, TCPListenerConfig): diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 2a824e8457..9b5ecf2c68 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py
@@ -54,6 +54,7 @@ from synapse.config.server import ListenerConfig, TCPListenerConfig from synapse.federation.transport.server import TransportLayerServer from synapse.http.additional_resource import AdditionalResource from synapse.http.server import ( + JsonResource, OptionsResource, RootOptionsRedirectResource, StaticResource, @@ -61,8 +62,7 @@ from synapse.http.server import ( from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource -from synapse.rest import ClientRestResource -from synapse.rest.admin import AdminRestResource +from synapse.rest import ClientRestResource, admin from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyResource from synapse.rest.synapse.client import build_synapse_client_resource_tree @@ -180,24 +180,18 @@ class SynapseHomeServer(HomeServer): if compress: client_resource = gz_wrap(client_resource) + admin_resource = JsonResource(self, canonical_json=False) + admin.register_servlets(self, admin_resource) + resources.update( { CLIENT_API_PREFIX: client_resource, "/.well-known": well_known_resource(self), - "/_synapse/admin": AdminRestResource(self), + "/_synapse/admin": admin_resource, **build_synapse_client_resource_tree(self), } ) - if self.config.email.can_verify_email: - from synapse.rest.synapse.client.password_reset import ( - PasswordResetSubmitTokenResource, - ) - - resources["/_synapse/client/password_reset/email/submit_token"] = ( - PasswordResetSubmitTokenResource(self) - ) - if name == "consent": from synapse.rest.consent.consent_resource import ConsentResource @@ -286,8 +280,7 @@ class SynapseHomeServer(HomeServer): elif listener.type == "metrics": if not self.config.metrics.enable_metrics: logger.warning( - "Metrics listener configured, but " - "enable_metrics is not True!" + "Metrics listener configured, but enable_metrics is not True!" 
) else: if isinstance(listener, TCPListenerConfig): @@ -349,12 +342,11 @@ def setup(config_options: List[str]) -> SynapseHomeServer: ): if ( not config.captcha.enable_registration_captcha - and not config.registration.registrations_require_3pid and not config.registration.registration_requires_token ): raise ConfigError( "You have enabled open registration without any verification. This is a known vector for " - "spam and abuse. If you would like to allow public registration, please consider adding email, " + "spam and abuse. If you would like to allow public registration, please consider adding " "captcha, or token-based verification. Otherwise this check can be removed by setting the " "`enable_registration_without_verification` config option to `true`." ) diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py
index f602bbbeea..07870a16ee 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py
@@ -34,6 +34,22 @@ if TYPE_CHECKING: logger = logging.getLogger("synapse.app.homeserver") +ONE_MINUTE_SECONDS = 60 +ONE_HOUR_SECONDS = 60 * ONE_MINUTE_SECONDS + +MILLISECONDS_PER_SECOND = 1000 + +INITIAL_DELAY_BEFORE_FIRST_PHONE_HOME_SECONDS = 5 * ONE_MINUTE_SECONDS +""" +We wait 5 minutes to send the first set of stats as the server can be quite busy the +first few minutes +""" + +PHONE_HOME_INTERVAL_SECONDS = 3 * ONE_HOUR_SECONDS +""" +Phone home stats are sent every 3 hours +""" + # Contains the list of processes we will be monitoring # currently either 0 or 1 _stats_process: List[Tuple[int, "resource.struct_rusage"]] = [] @@ -46,10 +62,6 @@ current_mau_by_service_gauge = Gauge( ["app_service"], ) max_mau_gauge = Gauge("synapse_admin_mau_max", "MAU Limit") -registered_reserved_users_mau_gauge = Gauge( - "synapse_admin_mau_registered_reserved_users", - "Registered users with reserved threepids", -) @wrap_as_background_process("phone_stats_home") @@ -185,12 +197,14 @@ def start_phone_stats_home(hs: "HomeServer") -> None: # If you increase the loop period, the accuracy of user_daily_visits # table will decrease clock.looping_call( - hs.get_datastores().main.generate_user_daily_visits, 5 * 60 * 1000 + hs.get_datastores().main.generate_user_daily_visits, + 5 * ONE_MINUTE_SECONDS * MILLISECONDS_PER_SECOND, ) # monthly active user limiting functionality clock.looping_call( - hs.get_datastores().main.reap_monthly_active_users, 1000 * 60 * 60 + hs.get_datastores().main.reap_monthly_active_users, + ONE_HOUR_SECONDS * MILLISECONDS_PER_SECOND, ) hs.get_datastores().main.reap_monthly_active_users() @@ -198,20 +212,17 @@ def start_phone_stats_home(hs: "HomeServer") -> None: async def generate_monthly_active_users() -> None: current_mau_count = 0 current_mau_count_by_service: Mapping[str, int] = {} - reserved_users: Sized = () store = hs.get_datastores().main if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only: current_mau_count = await 
store.get_monthly_active_count() current_mau_count_by_service = ( await store.get_monthly_active_count_by_service() ) - reserved_users = await store.get_registered_reserved_users() current_mau_gauge.set(float(current_mau_count)) for app_service, count in current_mau_count_by_service.items(): current_mau_by_service_gauge.labels(app_service).set(float(count)) - registered_reserved_users_mau_gauge.set(float(len(reserved_users))) max_mau_gauge.set(float(hs.config.server.max_mau_value)) if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only: @@ -221,7 +232,12 @@ def start_phone_stats_home(hs: "HomeServer") -> None: if hs.config.metrics.report_stats: logger.info("Scheduling stats reporting for 3 hour intervals") - clock.looping_call(phone_stats_home, 3 * 60 * 60 * 1000, hs, stats) + clock.looping_call( + phone_stats_home, + PHONE_HOME_INTERVAL_SECONDS * MILLISECONDS_PER_SECOND, + hs, + stats, + ) # We need to defer this init for the cases that we daemonize # otherwise the process ID we get is that of the non-daemon process @@ -229,4 +245,6 @@ def start_phone_stats_home(hs: "HomeServer") -> None: # We wait 5 minutes to send the first set of stats as the server can # be quite busy the first few minutes - clock.call_later(5 * 60, phone_stats_home, hs, stats) + clock.call_later( + INITIAL_DELAY_BEFORE_FIRST_PHONE_HOME_SECONDS, phone_stats_home, hs, stats + ) diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index a96cdbf1e7..6ee5240c4e 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py
@@ -87,6 +87,7 @@ class ApplicationService: ip_range_whitelist: Optional[IPSet] = None, supports_ephemeral: bool = False, msc3202_transaction_extensions: bool = False, + msc4190_device_management: bool = False, ): self.token = token self.url = ( @@ -100,6 +101,7 @@ class ApplicationService: self.ip_range_whitelist = ip_range_whitelist self.supports_ephemeral = supports_ephemeral self.msc3202_transaction_extensions = msc3202_transaction_extensions + self.msc4190_device_management = msc4190_device_management if "|" in self.id: raise Exception("application service ID cannot contain '|' character") diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index bec83419a2..cba08dde85 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py
@@ -2,7 +2,7 @@ # This file is licensed under the Affero General Public License (AGPL) version 3. # # Copyright 2015, 2016 OpenMarket Ltd -# Copyright (C) 2023 New Vector, Ltd +# Copyright (C) 2023, 2025 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as @@ -54,6 +54,7 @@ UP & quit +---------- YES SUCCESS This is all tied together by the AppServiceScheduler which DIs the required components. """ + import logging from typing import ( TYPE_CHECKING, @@ -69,6 +70,8 @@ from typing import ( Tuple, ) +from twisted.internet.interfaces import IDelayedCall + from synapse.appservice import ( ApplicationService, ApplicationServiceState, @@ -449,6 +452,20 @@ class _TransactionController: recoverer.recover() logger.info("Now %i active recoverers", len(self.recoverers)) + def force_retry(self, service: ApplicationService) -> None: + """Forces a Recoverer to attempt delivery of transations immediately. + + Args: + service: + """ + recoverer = self.recoverers.get(service.id) + if not recoverer: + # No need to force a retry on a happy AS. 
+ logger.info(f"{service.id} is not in recovery, not forcing retry") + return + + recoverer.force_retry() + async def _is_service_up(self, service: ApplicationService) -> bool: state = await self.store.get_appservice_state(service) return state == ApplicationServiceState.UP or state is None @@ -481,11 +498,12 @@ class _Recoverer: self.service = service self.callback = callback self.backoff_counter = 1 + self.scheduled_recovery: Optional[IDelayedCall] = None def recover(self) -> None: delay = 2**self.backoff_counter logger.info("Scheduling retries on %s in %fs", self.service.id, delay) - self.clock.call_later( + self.scheduled_recovery = self.clock.call_later( delay, run_as_background_process, "as-recoverer", self.retry ) @@ -495,6 +513,21 @@ class _Recoverer: self.backoff_counter += 1 self.recover() + def force_retry(self) -> None: + """Cancels the existing timer and forces an immediate retry in the background. + + Args: + service: + """ + # Prevent the existing backoff from occuring + if self.scheduled_recovery: + self.clock.cancel_call_later(self.scheduled_recovery) + # Run a retry, which will resechedule a recovery if it fails. + run_as_background_process( + "retry", + self.retry, + ) + async def retry(self) -> None: logger.info("Starting retries on %s", self.service.id) try: diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index adce34c03a..d367d45fea 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py
@@ -170,7 +170,7 @@ class Config: section: ClassVar[str] - def __init__(self, root_config: "RootConfig" = None): + def __init__(self, root_config: "RootConfig"): self.root = root_config # Get the path to the default Synapse template directory @@ -221,9 +221,13 @@ class Config: The number of milliseconds in the duration. Raises: - TypeError, if given something other than an integer or a string + TypeError: if given something other than an integer or a string, or the + duration is using an incorrect suffix. ValueError: if given a string not of the form described above. """ + # For integers, we prefer to use `type(value) is int` instead of + # `isinstance(value, int)` because we want to exclude subclasses of int, such as + # bool. if type(value) is int: # noqa: E721 return value elif isinstance(value, str): @@ -246,9 +250,20 @@ class Config: if suffix in sizes: value = value[:-1] size = sizes[suffix] + elif suffix.isdigit(): + # No suffix is treated as milliseconds. + value = value + size = 1 + else: + raise TypeError( + f"Bad duration suffix {value} (expected no suffix or one of these suffixes: {sizes.keys()})" + ) + return int(value) * size else: - raise TypeError(f"Bad duration {value!r}") + raise TypeError( + f"Bad duration type {value!r} (expected int or string duration)" + ) @staticmethod def abspath(file_path: str) -> str: @@ -430,7 +445,7 @@ class RootConfig: return res @classmethod - def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: any) -> None: + def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: Any) -> None: """ Invoke a static function on config objects this RootConfig is configured to use. 
@@ -574,6 +589,14 @@ class RootConfig: " Defaults to the directory containing the last config file", ) + config_parser.add_argument( + "--no-secrets-in-config", + dest="secrets_in_config", + action="store_false", + default=True, + help="Reject config options that expect an in-line secret as value.", + ) + cls.invoke_all_static("add_arguments", config_parser) @classmethod @@ -611,7 +634,10 @@ class RootConfig: config_dict = read_config_files(config_files) obj.parse_config_dict( - config_dict, config_dir_path=config_dir_path, data_dir_path=data_dir_path + config_dict, + config_dir_path=config_dir_path, + data_dir_path=data_dir_path, + allow_secrets_in_config=config_args.secrets_in_config, ) obj.invoke_all("read_arguments", config_args) @@ -638,6 +664,13 @@ class RootConfig: help="Specify config file. Can be given multiple times and" " may specify directories containing *.yaml files.", ) + parser.add_argument( + "--no-secrets-in-config", + dest="secrets_in_config", + action="store_false", + default=True, + help="Reject config options that expect an in-line secret as value.", + ) # we nest the mutually-exclusive group inside another group so that the help # text shows them in their own group. @@ -806,14 +839,21 @@ class RootConfig: return None obj.parse_config_dict( - config_dict, config_dir_path=config_dir_path, data_dir_path=data_dir_path + config_dict, + config_dir_path=config_dir_path, + data_dir_path=data_dir_path, + allow_secrets_in_config=config_args.secrets_in_config, ) obj.invoke_all("read_arguments", config_args) return obj def parse_config_dict( - self, config_dict: Dict[str, Any], config_dir_path: str, data_dir_path: str + self, + config_dict: Dict[str, Any], + config_dir_path: str, + data_dir_path: str, + allow_secrets_in_config: bool = True, ) -> None: """Read the information from the config dict into this Config object. 
@@ -831,6 +871,7 @@ class RootConfig: config_dict, config_dir_path=config_dir_path, data_dir_path=data_dir_path, + allow_secrets_in_config=allow_secrets_in_config, ) def generate_missing_files( @@ -1006,7 +1047,7 @@ class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig): return self._get_instance(key) -def read_file(file_path: Any, config_path: Iterable[str]) -> str: +def read_file(file_path: Any, config_path: StrSequence) -> str: """Check the given file exists, and read it into a string If it does not, emit an error indicating the problem diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index d9cb0da38b..baac814808 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi
@@ -30,7 +30,6 @@ from synapse.config import ( # noqa: F401 cas, consent, database, - emailconfig, experimental, federation, jwt, @@ -49,7 +48,6 @@ from synapse.config import ( # noqa: F401 retention, room, room_directory, - saml2, server, server_notices, spam_checker, @@ -59,6 +57,7 @@ from synapse.config import ( # noqa: F401 tls, tracer, user_directory, + user_types, voip, workers, ) @@ -96,13 +95,10 @@ class RootConfig: api: api.ApiConfig appservice: appservice.AppServiceConfig key: key.KeyConfig - saml2: saml2.SAML2Config - cas: cas.CasConfig sso: sso.SSOConfig oidc: oidc.OIDCConfig jwt: jwt.JWTConfig auth: auth.AuthConfig - email: emailconfig.EmailConfig worker: workers.WorkerConfig authproviders: password_auth_providers.PasswordAuthProviderConfig push: push.PushConfig @@ -122,6 +118,7 @@ class RootConfig: retention: retention.RetentionConfig background_updates: background_updates.BackgroundUpdateConfig auto_accept_invites: auto_accept_invites.AutoAcceptInvitesConfig + user_types: user_types.UserTypesConfig config_classes: List[Type["Config"]] = ... config_files: List[str] @@ -132,7 +129,11 @@ class RootConfig: @classmethod def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: Any) -> None: ... def parse_config_dict( - self, config_dict: Dict[str, Any], config_dir_path: str, data_dir_path: str + self, + config_dict: Dict[str, Any], + config_dir_path: str, + data_dir_path: str, + allow_secrets_in_config: bool = ..., ) -> None: ... def generate_config( self, @@ -175,7 +176,7 @@ class RootConfig: class Config: root: RootConfig default_template_dir: str - def __init__(self, root_config: Optional[RootConfig] = ...) -> None: ... + def __init__(self, root_config: RootConfig = ...) -> None: ... @staticmethod def parse_size(value: Union[str, int]) -> int: ... @staticmethod @@ -208,4 +209,4 @@ class ShardedWorkerHandlingConfig: class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig): def get_instance(self, key: str) -> str: ... 
# noqa: F811 -def read_file(file_path: Any, config_path: Iterable[str]) -> str: ... +def read_file(file_path: Any, config_path: StrSequence) -> str: ... diff --git a/synapse/config/_util.py b/synapse/config/_util.py
index 32b906a1ec..731b60a840 100644 --- a/synapse/config/_util.py +++ b/synapse/config/_util.py
@@ -18,17 +18,11 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import TYPE_CHECKING, Any, Dict, Type, TypeVar +from typing import Any, Dict, Type, TypeVar import jsonschema -from synapse._pydantic_compat import HAS_PYDANTIC_V2 - -if TYPE_CHECKING or HAS_PYDANTIC_V2: - from pydantic.v1 import BaseModel, ValidationError, parse_obj_as -else: - from pydantic import BaseModel, ValidationError, parse_obj_as - +from synapse._pydantic_compat import BaseModel, ValidationError, parse_obj_as from synapse.config._base import ConfigError from synapse.types import JsonDict, StrSequence diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index 6ff00e1ff8..dda6bcd1b7 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py
@@ -183,6 +183,18 @@ def _load_appservice( "The `org.matrix.msc3202` option should be true or false if specified." ) + # Opt-in flag for the MSC4190 behaviours. + # When enabled, the following C-S API endpoints change for appservices: + # - POST /register does not return an access token + # - PUT /devices/{device_id} creates a new device if one does not exist + # - DELETE /devices/{device_id} no longer requires UIA + # - POST /delete_devices/{device_id} no longer requires UIA + msc4190_enabled = as_info.get("io.element.msc4190", False) + if not isinstance(msc4190_enabled, bool): + raise ValueError( + "The `io.element.msc4190` option should be true or false if specified." + ) + return ApplicationService( token=as_info["as_token"], url=as_info["url"], @@ -195,4 +207,5 @@ def _load_appservice( ip_range_whitelist=ip_range_whitelist, supports_ephemeral=supports_ephemeral, msc3202_transaction_extensions=msc3202_transaction_extensions, + msc4190_device_management=msc4190_enabled, ) diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py
index 84897c09c5..57d67abbc3 100644 --- a/synapse/config/captcha.py +++ b/synapse/config/captcha.py
@@ -29,8 +29,15 @@ from ._base import Config, ConfigError class CaptchaConfig(Config): section = "captcha" - def read_config(self, config: JsonDict, **kwargs: Any) -> None: + def read_config( + self, config: JsonDict, allow_secrets_in_config: bool, **kwargs: Any + ) -> None: recaptcha_private_key = config.get("recaptcha_private_key") + if recaptcha_private_key and not allow_secrets_in_config: + raise ConfigError( + "Config options that expect an in-line secret as value are disabled", + ("recaptcha_private_key",), + ) if recaptcha_private_key is not None and not isinstance( recaptcha_private_key, str ): @@ -38,6 +45,11 @@ class CaptchaConfig(Config): self.recaptcha_private_key = recaptcha_private_key recaptcha_public_key = config.get("recaptcha_public_key") + if recaptcha_public_key and not allow_secrets_in_config: + raise ConfigError( + "Config options that expect an in-line secret as value are disabled", + ("recaptcha_public_key",), + ) if recaptcha_public_key is not None and not isinstance( recaptcha_public_key, str ): diff --git a/synapse/config/cas.py b/synapse/config/cas.py deleted file mode 100644
index fa59c350c1..0000000000 --- a/synapse/config/cas.py +++ /dev/null
@@ -1,111 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright 2021 The Matrix.org Foundation C.I.C. -# Copyright 2015, 2016 OpenMarket Ltd -# Copyright (C) 2023 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -# Originally licensed under the Apache License, Version 2.0: -# <http://www.apache.org/licenses/LICENSE-2.0>. -# -# [This file includes modifications made by New Vector Limited] -# -# - -from typing import Any, List - -from synapse.config.sso import SsoAttributeRequirement -from synapse.types import JsonDict - -from ._base import Config, ConfigError -from ._util import validate_config - - -class CasConfig(Config): - """Cas Configuration - - cas_server_url: URL of CAS server - """ - - section = "cas" - - def read_config(self, config: JsonDict, **kwargs: Any) -> None: - cas_config = config.get("cas_config", None) - self.cas_enabled = cas_config and cas_config.get("enabled", True) - - if self.cas_enabled: - self.cas_server_url = cas_config["server_url"] - - # TODO Update this to a _synapse URL. 
- public_baseurl = self.root.server.public_baseurl - self.cas_service_url = public_baseurl + "_matrix/client/r0/login/cas/ticket" - - self.cas_protocol_version = cas_config.get("protocol_version") - if ( - self.cas_protocol_version is not None - and self.cas_protocol_version not in [1, 2, 3] - ): - raise ConfigError( - "Unsupported CAS protocol version %s (only versions 1, 2, 3 are supported)" - % (self.cas_protocol_version,), - ("cas_config", "protocol_version"), - ) - self.cas_displayname_attribute = cas_config.get("displayname_attribute") - required_attributes = cas_config.get("required_attributes") or {} - self.cas_required_attributes = _parsed_required_attributes_def( - required_attributes - ) - - self.cas_enable_registration = cas_config.get("enable_registration", True) - - self.cas_allow_numeric_ids = cas_config.get("allow_numeric_ids") - self.cas_numeric_ids_prefix = cas_config.get("numeric_ids_prefix") - if ( - self.cas_numeric_ids_prefix is not None - and self.cas_numeric_ids_prefix.isalnum() is False - ): - raise ConfigError( - "Only alphanumeric characters are allowed for numeric IDs prefix", - ("cas_config", "numeric_ids_prefix"), - ) - - self.idp_name = cas_config.get("idp_name", "CAS") - self.idp_icon = cas_config.get("idp_icon") - self.idp_brand = cas_config.get("idp_brand") - - else: - self.cas_server_url = None - self.cas_service_url = None - self.cas_protocol_version = None - self.cas_displayname_attribute = None - self.cas_required_attributes = [] - self.cas_enable_registration = False - self.cas_allow_numeric_ids = False - self.cas_numeric_ids_prefix = "u" - - -# CAS uses a legacy required attributes mapping, not the one provided by -# SsoAttributeRequirement. 
-REQUIRED_ATTRIBUTES_SCHEMA = { - "type": "object", - "additionalProperties": {"anyOf": [{"type": "string"}, {"type": "null"}]}, -} - - -def _parsed_required_attributes_def( - required_attributes: Any, -) -> List[SsoAttributeRequirement]: - validate_config( - REQUIRED_ATTRIBUTES_SCHEMA, - required_attributes, - config_path=("cas_config", "required_attributes"), - ) - return [SsoAttributeRequirement(k, v) for k, v in required_attributes.items()] diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py deleted file mode 100644
index 8033fa2e52..0000000000 --- a/synapse/config/emailconfig.py +++ /dev/null
@@ -1,366 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright 2019 The Matrix.org Foundation C.I.C. -# Copyright 2015-2016 OpenMarket Ltd -# Copyright (C) 2023 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -# Originally licensed under the Apache License, Version 2.0: -# <http://www.apache.org/licenses/LICENSE-2.0>. -# -# [This file includes modifications made by New Vector Limited] -# -# - -# This file can't be called email.py because if it is, we cannot: -import email.utils -import logging -import os -from typing import Any - -import attr - -from synapse.types import JsonDict - -from ._base import Config, ConfigError - -logger = logging.getLogger(__name__) - -MISSING_PASSWORD_RESET_CONFIG_ERROR = """\ -Password reset emails are enabled on this homeserver due to a partial -'email' block. 
However, the following required keys are missing: - %s -""" - -DEFAULT_SUBJECTS = { - "message_from_person_in_room": "[%(app)s] You have a message on %(app)s from %(person)s in the %(room)s room...", - "message_from_person": "[%(app)s] You have a message on %(app)s from %(person)s...", - "messages_from_person": "[%(app)s] You have messages on %(app)s from %(person)s...", - "messages_in_room": "[%(app)s] You have messages on %(app)s in the %(room)s room...", - "messages_in_room_and_others": "[%(app)s] You have messages on %(app)s in the %(room)s room and others...", - "messages_from_person_and_others": "[%(app)s] You have messages on %(app)s from %(person)s and others...", - "invite_from_person": "[%(app)s] %(person)s has invited you to chat on %(app)s...", - "invite_from_person_to_room": "[%(app)s] %(person)s has invited you to join the %(room)s room on %(app)s...", - "invite_from_person_to_space": "[%(app)s] %(person)s has invited you to join the %(space)s space on %(app)s...", - "password_reset": "[%(server_name)s] Password reset", - "email_validation": "[%(server_name)s] Validate your email", - "email_already_in_use": "[%(server_name)s] Email already in use", -} - -LEGACY_TEMPLATE_DIR_WARNING = """ -This server's configuration file is using the deprecated 'template_dir' setting in the -'email' section. Support for this setting has been deprecated and will be removed in a -future version of Synapse. 
Server admins should instead use the new -'custom_template_directory' setting documented here: -https://element-hq.github.io/synapse/latest/templates.html ----------------------------------------------------------------------------------------""" - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class EmailSubjectConfig: - message_from_person_in_room: str - message_from_person: str - messages_from_person: str - messages_in_room: str - messages_in_room_and_others: str - messages_from_person_and_others: str - invite_from_person: str - invite_from_person_to_room: str - invite_from_person_to_space: str - password_reset: str - email_validation: str - email_already_in_use: str - - -class EmailConfig(Config): - section = "email" - - def read_config(self, config: JsonDict, **kwargs: Any) -> None: - # TODO: We should separate better the email configuration from the notification - # and account validity config. - - self.email_enable_notifs = False - - email_config = config.get("email") - if email_config is None: - email_config = {} - - self.force_tls = email_config.get("force_tls", False) - self.email_smtp_host = email_config.get("smtp_host", "localhost") - self.email_smtp_port = email_config.get( - "smtp_port", 465 if self.force_tls else 25 - ) - self.email_smtp_user = email_config.get("smtp_user", None) - self.email_smtp_pass = email_config.get("smtp_pass", None) - self.require_transport_security = email_config.get( - "require_transport_security", False - ) - self.enable_smtp_tls = email_config.get("enable_tls", True) - if self.force_tls and not self.enable_smtp_tls: - raise ConfigError("email.force_tls requires email.enable_tls to be true") - if self.require_transport_security and not self.enable_smtp_tls: - raise ConfigError( - "email.require_transport_security requires email.enable_tls to be true" - ) - - if "app_name" in email_config: - self.email_app_name = email_config["app_name"] - else: - self.email_app_name = "Matrix" - - # TODO: Rename notif_from to 
something more generic, or have a separate - # from for password resets, message notifications, etc? - # Currently the email section is a bit bogged down with settings for - # multiple functions. Would be good to split it out into separate - # sections and only put the common ones under email: - self.email_notif_from = email_config.get("notif_from", None) - if self.email_notif_from is not None: - # make sure it's valid - parsed = email.utils.parseaddr(self.email_notif_from) - if parsed[1] == "": - raise RuntimeError("Invalid notif_from address") - - # A user-configurable template directory - template_dir = email_config.get("template_dir") - if template_dir is not None: - logger.warning(LEGACY_TEMPLATE_DIR_WARNING) - - if isinstance(template_dir, str): - # We need an absolute path, because we change directory after starting (and - # we don't yet know what auxiliary templates like mail.css we will need). - template_dir = os.path.abspath(template_dir) - elif template_dir is not None: - # If template_dir is something other than a str or None, warn the user - raise ConfigError("Config option email.template_dir must be type str") - - self.email_enable_notifs = email_config.get("enable_notifs", False) - - if config.get("trust_identity_server_for_password_resets"): - raise ConfigError( - 'The config option "trust_identity_server_for_password_resets" ' - "is no longer supported. Please remove it from the config file." - ) - - # If we have email config settings, assume that we can verify ownership of - # email addresses. 
- self.can_verify_email = email_config != {} - - # Get lifetime of a validation token in milliseconds - self.email_validation_token_lifetime = self.parse_duration( - email_config.get("validation_token_lifetime", "1h") - ) - - if self.can_verify_email: - missing = [] - if not self.email_notif_from: - missing.append("email.notif_from") - - if missing: - raise ConfigError( - MISSING_PASSWORD_RESET_CONFIG_ERROR % (", ".join(missing),) - ) - - # These email templates have placeholders in them, and thus must be - # parsed using a templating engine during a request - password_reset_template_html = email_config.get( - "password_reset_template_html", "password_reset.html" - ) - password_reset_template_text = email_config.get( - "password_reset_template_text", "password_reset.txt" - ) - registration_template_html = email_config.get( - "registration_template_html", "registration.html" - ) - registration_template_text = email_config.get( - "registration_template_text", "registration.txt" - ) - already_in_use_template_html = email_config.get( - "already_in_use_template_html", "already_in_use.html" - ) - already_in_use_template_text = email_config.get( - "already_in_use_template_html", "already_in_use.txt" - ) - add_threepid_template_html = email_config.get( - "add_threepid_template_html", "add_threepid.html" - ) - add_threepid_template_text = email_config.get( - "add_threepid_template_text", "add_threepid.txt" - ) - - password_reset_template_failure_html = email_config.get( - "password_reset_template_failure_html", "password_reset_failure.html" - ) - registration_template_failure_html = email_config.get( - "registration_template_failure_html", "registration_failure.html" - ) - add_threepid_template_failure_html = email_config.get( - "add_threepid_template_failure_html", "add_threepid_failure.html" - ) - - # These templates do not support any placeholder variables, so we - # will read them from disk once during setup - password_reset_template_success_html = email_config.get( - 
"password_reset_template_success_html", "password_reset_success.html" - ) - registration_template_success_html = email_config.get( - "registration_template_success_html", "registration_success.html" - ) - add_threepid_template_success_html = email_config.get( - "add_threepid_template_success_html", "add_threepid_success.html" - ) - - # Read all templates from disk - ( - self.email_password_reset_template_html, - self.email_password_reset_template_text, - self.email_registration_template_html, - self.email_registration_template_text, - self.email_already_in_use_template_html, - self.email_already_in_use_template_text, - self.email_add_threepid_template_html, - self.email_add_threepid_template_text, - self.email_password_reset_template_confirmation_html, - self.email_password_reset_template_failure_html, - self.email_registration_template_failure_html, - self.email_add_threepid_template_failure_html, - password_reset_template_success_html_template, - registration_template_success_html_template, - add_threepid_template_success_html_template, - ) = self.read_templates( - [ - password_reset_template_html, - password_reset_template_text, - registration_template_html, - registration_template_text, - already_in_use_template_html, - already_in_use_template_text, - add_threepid_template_html, - add_threepid_template_text, - "password_reset_confirmation.html", - password_reset_template_failure_html, - registration_template_failure_html, - add_threepid_template_failure_html, - password_reset_template_success_html, - registration_template_success_html, - add_threepid_template_success_html, - ], - ( - td - for td in ( - self.root.server.custom_template_directory, - template_dir, - ) - if td - ), # Filter out template_dir if not provided - ) - - # Render templates that do not contain any placeholders - self.email_password_reset_template_success_html_content = ( - password_reset_template_success_html_template.render() - ) - self.email_registration_template_success_html_content = ( 
- registration_template_success_html_template.render() - ) - self.email_add_threepid_template_success_html_content = ( - add_threepid_template_success_html_template.render() - ) - - if self.email_enable_notifs: - missing = [] - if not self.email_notif_from: - missing.append("email.notif_from") - - if missing: - raise ConfigError( - "email.enable_notifs is True but required keys are missing: %s" - % (", ".join(missing),) - ) - - notif_template_html = email_config.get( - "notif_template_html", "notif_mail.html" - ) - notif_template_text = email_config.get( - "notif_template_text", "notif_mail.txt" - ) - - ( - self.email_notif_template_html, - self.email_notif_template_text, - ) = self.read_templates( - [notif_template_html, notif_template_text], - ( - td - for td in ( - self.root.server.custom_template_directory, - template_dir, - ) - if td - ), # Filter out template_dir if not provided - ) - - self.email_notif_for_new_users = email_config.get( - "notif_for_new_users", True - ) - self.email_riot_base_url = email_config.get( - "client_base_url", email_config.get("riot_base_url", None) - ) - # The amount of time we always wait before ever emailing about a notification - # (to give the user a chance to respond to other push or notice the window) - self.notif_delay_before_mail_ms = Config.parse_duration( - email_config.get("notif_delay_before_mail", "10m") - ) - - if self.root.account_validity.account_validity_renew_by_email_enabled: - expiry_template_html = email_config.get( - "expiry_template_html", "notice_expiry.html" - ) - expiry_template_text = email_config.get( - "expiry_template_text", "notice_expiry.txt" - ) - - ( - self.account_validity_template_html, - self.account_validity_template_text, - ) = self.read_templates( - [expiry_template_html, expiry_template_text], - ( - td - for td in ( - self.root.server.custom_template_directory, - template_dir, - ) - if td - ), # Filter out template_dir if not provided - ) - - subjects_config = email_config.get("subjects", 
{}) - subjects = {} - - for key, default in DEFAULT_SUBJECTS.items(): - subjects[key] = subjects_config.get(key, default) - - self.email_subjects = EmailSubjectConfig(**subjects) - - # The invite client location should be a HTTP(S) URL or None. - self.invite_client_location = email_config.get("invite_client_location") or None - if self.invite_client_location: - if not isinstance(self.invite_client_location, str): - raise ConfigError( - "Config option email.invite_client_location must be type str" - ) - if not ( - self.invite_client_location.startswith("http://") - or self.invite_client_location.startswith("https://") - ): - raise ConfigError( - "Config option email.invite_client_location must be a http or https URL", - path=("email", "invite_client_location"), - ) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index bae9cc8047..881aafc3f0 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py
@@ -20,6 +20,7 @@ # import enum +from functools import cache from typing import TYPE_CHECKING, Any, Optional import attr @@ -27,8 +28,8 @@ import attr.validators from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.config import ConfigError -from synapse.config._base import Config, RootConfig -from synapse.types import JsonDict +from synapse.config._base import Config, RootConfig, read_file +from synapse.types import JsonDict, StrSequence # Determine whether authlib is installed. try: @@ -43,6 +44,12 @@ if TYPE_CHECKING: from authlib.jose.rfc7517 import JsonWebKey +@cache +def read_secret_from_file_once(file_path: Any, config_path: StrSequence) -> str: + """Returns the memoized secret read from file.""" + return read_file(file_path, config_path).strip() + + class ClientAuthMethod(enum.Enum): """List of supported client auth methods.""" @@ -63,6 +70,40 @@ def _parse_jwks(jwks: Optional[JsonDict]) -> Optional["JsonWebKey"]: return JsonWebKey.import_key(jwks) +def _check_client_secret( + instance: "MSC3861", _attribute: attr.Attribute, _value: Optional[str] +) -> None: + if instance._client_secret and instance._client_secret_path: + raise ConfigError( + ( + "You have configured both " + "`experimental_features.msc3861.client_secret` and " + "`experimental_features.msc3861.client_secret_path`. " + "These are mutually incompatible." + ), + ("experimental", "msc3861", "client_secret"), + ) + # Check client secret can be retrieved + instance.client_secret() + + +def _check_admin_token( + instance: "MSC3861", _attribute: attr.Attribute, _value: Optional[str] +) -> None: + if instance._admin_token and instance._admin_token_path: + raise ConfigError( + ( + "You have configured both " + "`experimental_features.msc3861.admin_token` and " + "`experimental_features.msc3861.admin_token_path`. " + "These are mutually incompatible." 
+ ), + ("experimental", "msc3861", "admin_token"), + ) + # Check admin token can be retrieved + instance.admin_token() + + @attr.s(slots=True, frozen=True) class MSC3861: """Configuration for MSC3861: Matrix architecture change to delegate authentication via OIDC""" @@ -97,15 +138,30 @@ class MSC3861: ) """The auth method used when calling the introspection endpoint.""" - client_secret: Optional[str] = attr.ib( + _client_secret: Optional[str] = attr.ib( default=None, - validator=attr.validators.optional(attr.validators.instance_of(str)), + validator=[ + attr.validators.optional(attr.validators.instance_of(str)), + _check_client_secret, + ], ) """ The client secret to use when calling the introspection endpoint, when using any of the client_secret_* client auth methods. """ + _client_secret_path: Optional[str] = attr.ib( + default=None, + validator=[ + attr.validators.optional(attr.validators.instance_of(str)), + _check_client_secret, + ], + ) + """ + Alternative to `client_secret`: allows the secret to be specified in an + external file. + """ + jwk: Optional["JsonWebKey"] = attr.ib(default=None, converter=_parse_jwks) """ The JWKS to use when calling the introspection endpoint, @@ -133,7 +189,7 @@ class MSC3861: ClientAuthMethod.CLIENT_SECRET_BASIC, ClientAuthMethod.CLIENT_SECRET_JWT, ) - and self.client_secret is None + and self.client_secret() is None ): raise ConfigError( f"A client secret must be provided when using the {value} client auth method", @@ -152,16 +208,51 @@ class MSC3861: ) """The URL of the My Account page on the OIDC Provider as per MSC2965.""" - admin_token: Optional[str] = attr.ib( + _admin_token: Optional[str] = attr.ib( default=None, - validator=attr.validators.optional(attr.validators.instance_of(str)), + validator=[ + attr.validators.optional(attr.validators.instance_of(str)), + _check_admin_token, + ], ) """ A token that should be considered as an admin token. This is used by the OIDC provider, to make admin calls to Synapse. 
""" - def check_config_conflicts(self, root: RootConfig) -> None: + _admin_token_path: Optional[str] = attr.ib( + default=None, + validator=[ + attr.validators.optional(attr.validators.instance_of(str)), + _check_admin_token, + ], + ) + """ + Alternative to `admin_token`: allows the secret to be specified in an + external file. + """ + + def client_secret(self) -> Optional[str]: + """Returns the secret given via `client_secret` or `client_secret_path`.""" + if self._client_secret_path: + return read_secret_from_file_once( + self._client_secret_path, + ("experimental_features", "msc3861", "client_secret_path"), + ) + return self._client_secret + + def admin_token(self) -> Optional[str]: + """Returns the admin token given via `admin_token` or `admin_token_path`.""" + if self._admin_token_path: + return read_secret_from_file_once( + self._admin_token_path, + ("experimental_features", "msc3861", "admin_token_path"), + ) + return self._admin_token + + def check_config_conflicts( + self, root: RootConfig, allow_secrets_in_config: bool + ) -> None: """Checks for any configuration conflicts with other parts of Synapse. 
Raises: @@ -171,6 +262,24 @@ class MSC3861: if not self.enabled: return + if self._client_secret and not allow_secrets_in_config: + raise ConfigError( + "Config options that expect an in-line secret as value are disabled", + ("experimental", "msc3861", "client_secret"), + ) + + if self.jwk and not allow_secrets_in_config: + raise ConfigError( + "Config options that expect an in-line secret as value are disabled", + ("experimental", "msc3861", "jwk"), + ) + + if self._admin_token and not allow_secrets_in_config: + raise ConfigError( + "Config options that expect an in-line secret as value are disabled", + ("experimental", "msc3861", "admin_token"), + ) + if ( root.auth.password_enabled_for_reauth or root.auth.password_enabled_for_login @@ -195,8 +304,6 @@ class MSC3861: if ( root.oidc.oidc_enabled - or root.saml2.saml2_enabled - or root.cas.cas_enabled or root.jwt.jwt_enabled ): raise ConfigError("SSO cannot be enabled when OAuth delegation is enabled") @@ -236,12 +343,6 @@ class MSC3861: ("session_lifetime",), ) - if root.registration.enable_3pid_changes: - raise ConfigError( - "enable_3pid_changes cannot be enabled when OAuth delegation is enabled", - ("enable_3pid_changes",), - ) - @attr.s(auto_attribs=True, frozen=True, slots=True) class MSC3866Config: @@ -261,7 +362,9 @@ class ExperimentalConfig(Config): section = "experimental" - def read_config(self, config: JsonDict, **kwargs: Any) -> None: + def read_config( + self, config: JsonDict, allow_secrets_in_config: bool, **kwargs: Any + ) -> None: experimental = config.get("experimental_features") or {} # MSC3026 (busy presence state) @@ -288,9 +391,6 @@ class ExperimentalConfig(Config): ), ) - # MSC3244 (room version capabilities) - self.msc3244_enabled: bool = experimental.get("msc3244_enabled", True) - # MSC3266 (room summary api) self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False) @@ -338,8 +438,10 @@ class ExperimentalConfig(Config): # MSC3391: Removing account data. 
self.msc3391_enabled = experimental.get("msc3391_enabled", False) - # MSC3575 (Sliding Sync API endpoints) - self.msc3575_enabled: bool = experimental.get("msc3575_enabled", False) + # MSC3575 (Sliding Sync) alternate endpoints, c.f. MSC4186. + # + # This is enabled by default as a replacement for the sliding sync proxy. + self.msc3575_enabled: bool = experimental.get("msc3575_enabled", True) # MSC3773: Thread notifications self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False) @@ -363,11 +465,6 @@ class ExperimentalConfig(Config): # MSC3874: Filtering /messages with rel_types / not_rel_types. self.msc3874_enabled: bool = experimental.get("msc3874_enabled", False) - # MSC3886: Simple client rendezvous capability - self.msc3886_endpoint: Optional[str] = experimental.get( - "msc3886_endpoint", None - ) - # MSC3890: Remotely silence local notifications # Note: This option requires "experimental_features.msc3391_enabled" to be # set to "true", in order to communicate account data deletions to clients. 
@@ -408,7 +505,9 @@ class ExperimentalConfig(Config): ) from exc # Check that none of the other config options conflict with MSC3861 when enabled - self.msc3861.check_config_conflicts(self.root) + self.msc3861.check_config_conflicts( + self.root, allow_secrets_in_config=allow_secrets_in_config + ) self.msc4028_push_encrypted_events = experimental.get( "msc4028_push_encrypted_events", False @@ -439,12 +538,23 @@ class ExperimentalConfig(Config): ("experimental", "msc4108_delegation_endpoint"), ) - self.msc3823_account_suspension = experimental.get( - "msc3823_account_suspension", False - ) + # MSC4133: Custom profile fields + self.msc4133_enabled: bool = experimental.get("msc4133_enabled", False) + + # MSC4210: Remove legacy mentions + self.msc4210_enabled: bool = experimental.get("msc4210_enabled", False) + + # MSC4222: Adding `state_after` to sync v2 + self.msc4222_enabled: bool = experimental.get("msc4222_enabled", False) - # MSC4151: Report room API (Client-Server API) - self.msc4151_enabled: bool = experimental.get("msc4151_enabled", False) + # MSC4076: Add `disable_badge_count` to pusher configuration + self.msc4076_enabled: bool = experimental.get("msc4076_enabled", False) + + # MSC4263: Preventing MXID enumeration via key queries + self.msc4263_limit_key_queries_to_users_who_share_rooms = experimental.get( + "msc4263_limit_key_queries_to_users_who_share_rooms", + False, + ) - # MSC4156: Migrate server_name to via - self.msc4156_enabled: bool = experimental.get("msc4156_enabled", False) + # MSC4155: Invite filtering + self.msc4155_enabled: bool = experimental.get("msc4155_enabled", False) diff --git a/synapse/config/federation.py b/synapse/config/federation.py
index cf29fa2562..31f46e420d 100644 --- a/synapse/config/federation.py +++ b/synapse/config/federation.py
@@ -94,5 +94,21 @@ class FederationConfig(Config): 2**62, ) + def is_domain_allowed_according_to_federation_whitelist(self, domain: str) -> bool: + """ + Returns whether a domain is allowed according to the federation whitelist. If a + federation whitelist is not set, all domains are allowed. + + Args: + domain: The domain to test. + + Returns: + True if the domain is allowed or if a whitelist is not set, False otherwise. + """ + if self.federation_domain_whitelist is None: + return True + + return domain in self.federation_domain_whitelist + _METRICS_FOR_DOMAINS_SCHEMA = {"type": "array", "items": {"type": "string"}} diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index e36c0bd6ae..969261cb11 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py
@@ -27,10 +27,8 @@ from .auto_accept_invites import AutoAcceptInvitesConfig from .background_updates import BackgroundUpdateConfig from .cache import CacheConfig from .captcha import CaptchaConfig -from .cas import CasConfig from .consent import ConsentConfig from .database import DatabaseConfig -from .emailconfig import EmailConfig from .experimental import ExperimentalConfig from .federation import FederationConfig from .jwt import JWTConfig @@ -49,7 +47,6 @@ from .repository import ContentRepositoryConfig from .retention import RetentionConfig from .room import RoomConfig from .room_directory import RoomDirectoryConfig -from .saml2 import SAML2Config from .server import ServerConfig from .server_notices import ServerNoticesConfig from .spam_checker import SpamCheckerConfig @@ -59,6 +56,7 @@ from .third_party_event_rules import ThirdPartyRulesConfig from .tls import TlsConfig from .tracer import TracerConfig from .user_directory import UserDirectoryConfig +from .user_types import UserTypesConfig from .voip import VoipConfig from .workers import WorkerConfig @@ -84,13 +82,10 @@ class HomeServerConfig(RootConfig): ApiConfig, AppServiceConfig, KeyConfig, - SAML2Config, OIDCConfig, - CasConfig, SSOConfig, JWTConfig, AuthConfig, - EmailConfig, PasswordAuthProviderConfig, PushConfig, SpamCheckerConfig, @@ -107,4 +102,5 @@ class HomeServerConfig(RootConfig): ExperimentalConfig, BackgroundUpdateConfig, AutoAcceptInvitesConfig, + UserTypesConfig, ] diff --git a/synapse/config/jwt.py b/synapse/config/jwt.py
index b41f2dc08f..5c76551f33 100644 --- a/synapse/config/jwt.py +++ b/synapse/config/jwt.py
@@ -38,6 +38,7 @@ class JWTConfig(Config): self.jwt_algorithm = jwt_config["algorithm"] self.jwt_subject_claim = jwt_config.get("subject_claim", "sub") + self.jwt_display_name_claim = jwt_config.get("display_name_claim") # The issuer and audiences are optional, if provided, it is asserted # that the claims exist on the JWT. @@ -49,5 +50,6 @@ class JWTConfig(Config): self.jwt_secret = None self.jwt_algorithm = None self.jwt_subject_claim = None + self.jwt_display_name_claim = None self.jwt_issuer = None self.jwt_audiences = None diff --git a/synapse/config/key.py b/synapse/config/key.py
index b9925a52d2..29c558448b 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py
@@ -43,7 +43,7 @@ from unpaddedbase64 import decode_base64 from synapse.types import JsonDict from synapse.util.stringutils import random_string, random_string_with_symbols -from ._base import Config, ConfigError +from ._base import Config, ConfigError, read_file if TYPE_CHECKING: from signedjson.key import VerifyKeyWithExpiry @@ -91,6 +91,16 @@ To suppress this warning and continue using 'matrix.org', admins should set 'suppress_key_server_warning' to 'true' in homeserver.yaml. --------------------------------------------------------------------------------""" +CONFLICTING_MACAROON_SECRET_KEY_OPTS_ERROR = """\ +Conflicting options 'macaroon_secret_key' and 'macaroon_secret_key_path' are +both defined in config file. +""" + +CONFLICTING_FORM_SECRET_OPTS_ERROR = """\ +Conflicting options 'form_secret' and 'form_secret_path' are both defined in +config file. +""" + logger = logging.getLogger(__name__) @@ -107,7 +117,11 @@ class KeyConfig(Config): section = "key" def read_config( - self, config: JsonDict, config_dir_path: str, **kwargs: Any + self, + config: JsonDict, + config_dir_path: str, + allow_secrets_in_config: bool, + **kwargs: Any, ) -> None: # the signing key can be specified inline or in a separate file if "signing_key" in config: @@ -166,10 +180,21 @@ class KeyConfig(Config): ) ) - macaroon_secret_key: Optional[str] = config.get( - "macaroon_secret_key", self.root.registration.registration_shared_secret - ) - + macaroon_secret_key = config.get("macaroon_secret_key") + if macaroon_secret_key and not allow_secrets_in_config: + raise ConfigError( + "Config options that expect an in-line secret as value are disabled", + ("macaroon_secret_key",), + ) + macaroon_secret_key_path = config.get("macaroon_secret_key_path") + if macaroon_secret_key_path: + if macaroon_secret_key: + raise ConfigError(CONFLICTING_MACAROON_SECRET_KEY_OPTS_ERROR) + macaroon_secret_key = read_file( + macaroon_secret_key_path, ("macaroon_secret_key_path",) + ).strip() + if not 
macaroon_secret_key: + macaroon_secret_key = self.root.registration.registration_shared_secret if not macaroon_secret_key: # Unfortunately, there are people out there that don't have this # set. Lets just be "nice" and derive one from their secret key. @@ -181,7 +206,21 @@ class KeyConfig(Config): # a secret which is used to calculate HMACs for form values, to stop # falsification of values - self.form_secret = config.get("form_secret", None) + form_secret = config.get("form_secret", None) + if form_secret and not allow_secrets_in_config: + raise ConfigError( + "Config options that expect an in-line secret as value are disabled", + ("form_secret",), + ) + form_secret_path = config.get("form_secret_path", None) + if form_secret_path: + if form_secret: + raise ConfigError(CONFLICTING_FORM_SECRET_OPTS_ERROR) + self.form_secret = read_file( + form_secret_path, ("form_secret_path",) + ).strip() + else: + self.form_secret = form_secret def generate_config_section( self, @@ -200,16 +239,13 @@ class KeyConfig(Config): ) form_secret = 'form_secret: "%s"' % random_string_with_symbols(50) - return ( - """\ + return """\ %(macaroon_secret_key)s %(form_secret)s signing_key_path: "%(base_key_name)s.signing.key" trusted_key_servers: - server_name: "matrix.org" - """ - % locals() - ) + """ % locals() def read_signing_keys(self, signing_key_path: str, name: str) -> List[SigningKey]: """Read the signing keys in the given path. @@ -249,7 +285,9 @@ class KeyConfig(Config): if is_signing_algorithm_supported(key_id): key_base64 = key_data["key"] key_bytes = decode_base64(key_base64) - verify_key: "VerifyKeyWithExpiry" = decode_verify_key_bytes(key_id, key_bytes) # type: ignore[assignment] + verify_key: "VerifyKeyWithExpiry" = decode_verify_key_bytes( + key_id, key_bytes + ) # type: ignore[assignment] verify_key.expired = key_data["expired_ts"] keys[key_id] = verify_key else: diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index fca0b08d6d..110ff75c63 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py
@@ -132,24 +132,11 @@ disable_existing_loggers: false """ ) -LOG_FILE_ERROR = """\ -Support for the log_file configuration option and --log-file command-line option was -removed in Synapse 1.3.0. You should instead set up a separate log configuration file. -""" - -STRUCTURED_ERROR = """\ -Support for the structured configuration option was removed in Synapse 1.54.0. -You should instead use the standard logging configuration. See -https://element-hq.github.io/synapse/v1.54/structured_logging.html -""" - class LoggingConfig(Config): section = "logging" def read_config(self, config: JsonDict, **kwargs: Any) -> None: - if config.get("log_file"): - raise ConfigError(LOG_FILE_ERROR) self.log_config = self.abspath(config.get("log_config")) self.no_redirect_stdio = config.get("no_redirect_stdio", False) @@ -157,18 +144,13 @@ class LoggingConfig(Config): self, config_dir_path: str, server_name: str, **kwargs: Any ) -> str: log_config = os.path.join(config_dir_path, server_name + ".log.config") - return ( - """\ + return """\ log_config: "%(log_config)s" - """ - % locals() - ) + """ % locals() def read_arguments(self, args: argparse.Namespace) -> None: if args.no_redirect_stdio is not None: self.no_redirect_stdio = args.no_redirect_stdio - if args.log_file is not None: - raise ConfigError(LOG_FILE_ERROR) @staticmethod def add_arguments(parser: argparse.ArgumentParser) -> None: @@ -296,10 +278,6 @@ def _load_logging_config(log_config_path: str) -> None: if not log_config: logging.warning("Loaded a blank logging config?") - # If the old structured logging configuration is being used, raise an error. - if "structured" in log_config and log_config.get("structured"): - raise ConfigError(STRUCTURED_ERROR) - logging.config.dictConfig(log_config) # Blow away the pyo3-log cache so that it reloads the configuration. @@ -363,5 +341,6 @@ def setup_logging( "Licensed under the AGPL 3.0 license. 
Website: https://github.com/element-hq/synapse" ) logging.info("Server hostname: %s", config.server.server_name) + logging.info("Public Base URL: %s", config.server.public_baseurl) logging.info("Instance name: %s", hs.get_instance_name()) logging.info("Twisted reactor: %s", type(reactor).__name__) diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py
index d0a03baf55..b18654ff6a 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py
@@ -125,6 +125,10 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { "enum": ["client_secret_basic", "client_secret_post", "none"], }, "pkce_method": {"type": "string", "enum": ["auto", "always", "never"]}, + "id_token_signing_alg_values_supported": { + "type": "array", + "items": {"type": "string"}, + }, "scopes": {"type": "array", "items": {"type": "string"}}, "authorization_endpoint": {"type": "string"}, "token_endpoint": {"type": "string"}, @@ -137,6 +141,9 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { "type": "string", "enum": ["auto", "userinfo_endpoint"], }, + "redirect_uri": { + "type": ["string", "null"], + }, "allow_existing_users": {"type": "boolean"}, "user_mapping_provider": {"type": ["object", "null"]}, "attribute_requirements": { @@ -248,7 +255,7 @@ def _parse_oidc_config_dict( idp_id = oidc_config.get("idp_id", "oidc") # prefix the given IDP with a prefix specific to the SSO mechanism, to avoid - # clashes with other mechs (such as SAML, CAS). + # clashes with other mechs). # # We allow "oidc" as an exception so that people migrating from old-style # "oidc_config" format (which has long used "oidc" as its idp_id) can migrate to @@ -326,6 +333,9 @@ def _parse_oidc_config_dict( client_secret_jwt_key=client_secret_jwt_key, client_auth_method=client_auth_method, pkce_method=oidc_config.get("pkce_method", "auto"), + id_token_signing_alg_values_supported=oidc_config.get( + "id_token_signing_alg_values_supported" + ), scopes=oidc_config.get("scopes", ["openid"]), authorization_endpoint=oidc_config.get("authorization_endpoint"), token_endpoint=oidc_config.get("token_endpoint"), @@ -337,6 +347,7 @@ def _parse_oidc_config_dict( ), skip_verification=oidc_config.get("skip_verification", False), user_profile_method=oidc_config.get("user_profile_method", "auto"), + redirect_uri=oidc_config.get("redirect_uri"), allow_existing_users=oidc_config.get("allow_existing_users", False), user_mapping_provider_class=user_mapping_provider_class, 
user_mapping_provider_config=user_mapping_provider_config, @@ -345,6 +356,9 @@ def _parse_oidc_config_dict( additional_authorization_parameters=oidc_config.get( "additional_authorization_parameters", {} ), + passthrough_authorization_parameters=oidc_config.get( + "passthrough_authorization_parameters", [] + ), ) @@ -402,6 +416,34 @@ class OidcProviderConfig: # Valid values are 'auto', 'always', and 'never'. pkce_method: str + id_token_signing_alg_values_supported: Optional[List[str]] + """ + List of the JWS signing algorithms (`alg` values) that are supported for signing the + `id_token`. + + This is *not* required if `discovery` is disabled. We default to supporting `RS256` + in the downstream usage if no algorithms are configured here or in the discovery + document. + + According to the spec, the algorithm `"RS256"` MUST be included. The absolute rigid + approach would be to reject this provider as non-compliant if it's not included but + we can just allow whatever and see what happens (they're the ones that configured + the value and cooperating with the identity provider). It wouldn't be wise to add it + ourselves because absence of `RS256` might indicate that the provider actually + doesn't support it, despite the spec requirement. Adding it silently could lead to + failed authentication attempts or strange mismatch attacks. + + The `alg` value `"none"` MAY be supported but can only be used if the Authorization + Endpoint does not include `id_token` in the `response_type` (ex. + `/authorize?response_type=code` where `none` can apply, + `/authorize?response_type=code%20id_token` where `none` can't apply) (such as when + using the Authorization Code Flow). 
+ + Spec: + - https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata + - https://openid.net/specs/openid-connect-core-1_0.html#AuthorizationExamples + """ + # list of scopes to request scopes: Collection[str] @@ -432,6 +474,18 @@ class OidcProviderConfig: # values are: "auto" or "userinfo_endpoint". user_profile_method: str + redirect_uri: Optional[str] + """ + An optional replacement for Synapse's hardcoded `redirect_uri` URL + (`<public_baseurl>/_synapse/client/oidc/callback`). This can be used to send + the client to a different URL after it receives a response from the + `authorization_endpoint`. + + If this is set, the client is expected to call Synapse's OIDC callback URL + reproduced above itself with the necessary parameters and session cookie, in + order to complete OIDC login. + """ + # whether to allow a user logging in via OIDC to match a pre-existing account # instead of failing allow_existing_users: bool @@ -450,3 +504,6 @@ class OidcProviderConfig: # Additional parameters that will be passed to the authorization grant URL additional_authorization_parameters: Mapping[str, str] + + # Allow query parameters to the redirect endpoint that will be passed to the authorization grant URL + passthrough_authorization_parameters: Collection[str] diff --git a/synapse/config/push.py b/synapse/config/push.py
index bc24833702..5ca624e69e 100644 --- a/synapse/config/push.py +++ b/synapse/config/push.py
@@ -37,26 +37,12 @@ class PushConfig(Config): "group_unread_count_by_room", True ) - # There was a a 'redact_content' setting but mistakenly read from the - # 'email'section'. Check for the flag in the 'push' section, and log, - # but do not honour it to avoid nasty surprises when people upgrade. if push_config.get("redact_content") is not None: print( "The push.redact_content content option has never worked. " "Please set push.include_content if you want this behaviour" ) - # Now check for the one in the 'email' section and honour it, - # with a warning. - email_push_config = config.get("email") or {} - redact_content = email_push_config.get("redact_content") - if redact_content is not None: - print( - "The 'email.redact_content' option is deprecated: " - "please set push.include_content instead" - ) - self.push_include_content = not redact_content - # Whether to apply a random delay to outbound push. self.push_jitter_delay_ms = None push_jitter_delay = push_config.get("jitter_delay", None) diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 3fa33f5373..eb1dc2dacb 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py
@@ -228,3 +228,15 @@ class RatelimitConfig(Config): config.get("remote_media_download_burst_count", "500M") ), ) + + self.rc_presence_per_user = RatelimitSettings.parse( + config, + "rc_presence.per_user", + defaults={"per_second": 0.1, "burst_count": 1}, + ) + + self.rc_delayed_event_mgmt = RatelimitSettings.parse( + config, + "rc_delayed_event_mgmt", + defaults={"per_second": 1, "burst_count": 5}, + ) diff --git a/synapse/config/redis.py b/synapse/config/redis.py
index f140538088..948c95eef7 100644 --- a/synapse/config/redis.py +++ b/synapse/config/redis.py
@@ -21,15 +21,22 @@ from typing import Any -from synapse.config._base import Config +from synapse.config._base import Config, ConfigError, read_file from synapse.types import JsonDict from synapse.util.check_dependencies import check_requirements +CONFLICTING_PASSWORD_OPTS_ERROR = """\ +You have configured both `redis.password` and `redis.password_path`. +These are mutually incompatible. +""" + class RedisConfig(Config): section = "redis" - def read_config(self, config: JsonDict, **kwargs: Any) -> None: + def read_config( + self, config: JsonDict, allow_secrets_in_config: bool, **kwargs: Any + ) -> None: redis_config = config.get("redis") or {} self.redis_enabled = redis_config.get("enabled", False) @@ -43,6 +50,22 @@ class RedisConfig(Config): self.redis_path = redis_config.get("path", None) self.redis_dbid = redis_config.get("dbid", None) self.redis_password = redis_config.get("password") + if self.redis_password and not allow_secrets_in_config: + raise ConfigError( + "Config options that expect an in-line secret as value are disabled", + ("redis", "password"), + ) + redis_password_path = redis_config.get("password_path") + if redis_password_path: + if self.redis_password: + raise ConfigError(CONFLICTING_PASSWORD_OPTS_ERROR) + self.redis_password = read_file( + redis_password_path, + ( + "redis", + "password_path", + ), + ).strip() self.redis_use_tls = redis_config.get("use_tls", False) self.redis_certificate = redis_config.get("certificate_file", None) diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index c7f3e6d35e..b3fa500d4e 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py
@@ -27,13 +27,6 @@ from synapse.config._base import Config, ConfigError, read_file from synapse.types import JsonDict, RoomAlias, UserID from synapse.util.stringutils import random_string_with_symbols, strtobool -NO_EMAIL_DELEGATE_ERROR = """\ -Delegation of email verification to an identity server is no longer supported. To -continue to allow users to add email addresses to their accounts, and use them for -password resets, configure Synapse with an SMTP server via the `email` setting, and -remove `account_threepid_delegates.email`. -""" - CONFLICTING_SHARED_SECRET_OPTS_ERROR = """\ You have configured both `registration_shared_secret` and `registration_shared_secret_path`. These are mutually incompatible. @@ -43,7 +36,9 @@ You have configured both `registration_shared_secret` and class RegistrationConfig(Config): section = "registration" - def read_config(self, config: JsonDict, **kwargs: Any) -> None: + def read_config( + self, config: JsonDict, allow_secrets_in_config: bool, **kwargs: Any + ) -> None: self.enable_registration = strtobool( str(config.get("enable_registration", False)) ) @@ -56,18 +51,17 @@ class RegistrationConfig(Config): str(config.get("enable_registration_without_verification", False)) ) - self.registrations_require_3pid = config.get("registrations_require_3pid", []) - self.allowed_local_3pids = config.get("allowed_local_3pids", []) - self.enable_3pid_lookup = config.get("enable_3pid_lookup", True) self.registration_requires_token = config.get( "registration_requires_token", False ) - self.enable_registration_token_3pid_bypass = config.get( - "enable_registration_token_3pid_bypass", False - ) # read the shared secret, either inline or from an external file self.registration_shared_secret = config.get("registration_shared_secret") + if self.registration_shared_secret and not allow_secrets_in_config: + raise ConfigError( + "Config options that expect an in-line secret as value are disabled", + ("registration_shared_secret",), + ) 
registration_shared_secret_path = config.get("registration_shared_secret_path") if registration_shared_secret_path: if self.registration_shared_secret: @@ -78,16 +72,8 @@ class RegistrationConfig(Config): self.bcrypt_rounds = config.get("bcrypt_rounds", 12) - account_threepid_delegates = config.get("account_threepid_delegates") or {} - if "email" in account_threepid_delegates: - raise ConfigError(NO_EMAIL_DELEGATE_ERROR) - self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn") - self.default_identity_server = config.get("default_identity_server") self.allow_guest_access = config.get("allow_guest_access", False) - if config.get("invite_3pid_guest", False): - raise ConfigError("invite_3pid_guest is no longer supported") - self.auto_join_rooms = config.get("auto_join_rooms", []) for room_alias in self.auto_join_rooms: if not RoomAlias.is_valid(room_alias): @@ -147,12 +133,9 @@ class RegistrationConfig(Config): .get("msc3861", {}) .get("enabled", False) ) - self.enable_3pid_changes = config.get( - "enable_3pid_changes", not msc3861_enabled - ) - self.disable_msisdn_registration = config.get( - "disable_msisdn_registration", False + self.allow_underscore_prefixed_localpart = config.get( + "allow_underscore_prefixed_localpart", False ) session_lifetime = config.get("session_lifetime") diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 97ce6de528..fc5a90c85a 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py
@@ -22,7 +22,7 @@ import logging import os from typing import Any, Dict, List, Tuple -from urllib.request import getproxies_environment # type: ignore +from urllib.request import getproxies_environment import attr @@ -272,9 +272,7 @@ class ContentRepositoryConfig(Config): remote_media_lifetime ) - self.enable_authenticated_media = config.get( - "enable_authenticated_media", False - ) + self.enable_authenticated_media = config.get("enable_authenticated_media", True) def generate_config_section(self, data_dir_path: str, **kwargs: Any) -> str: assert data_dir_path is not None diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py
index 704895cf9a..f0349b68f2 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py
@@ -54,9 +54,7 @@ class RoomDirectoryConfig(Config): for rule in room_list_publication_rules ] else: - self._room_list_publication_rules = [ - _RoomDirectoryRule("room_list_publication_rules", {"action": "allow"}) - ] + self._room_list_publication_rules = [] def is_alias_creation_allowed(self, user_id: str, room_id: str, alias: str) -> bool: """Checks if the given user is allowed to create the given alias diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py deleted file mode 100644
index 9d7ef94507..0000000000 --- a/synapse/config/saml2.py +++ /dev/null
@@ -1,248 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright 2019-2021 The Matrix.org Foundation C.I.C. -# Copyright (C) 2023 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -# Originally licensed under the Apache License, Version 2.0: -# <http://www.apache.org/licenses/LICENSE-2.0>. -# -# [This file includes modifications made by New Vector Limited] -# -# - -import logging -from typing import Any, List, Set - -from synapse.config.sso import SsoAttributeRequirement -from synapse.types import JsonDict -from synapse.util.check_dependencies import check_requirements -from synapse.util.module_loader import load_module, load_python_module - -from ._base import Config, ConfigError -from ._util import validate_config - -logger = logging.getLogger(__name__) - -DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.saml.DefaultSamlMappingProvider" -# The module that DefaultSamlMappingProvider is in was renamed, we want to -# transparently handle both the same. -LEGACY_USER_MAPPING_PROVIDER = ( - "synapse.handlers.saml_handler.DefaultSamlMappingProvider" -) - - -def _dict_merge(merge_dict: dict, into_dict: dict) -> None: - """Do a deep merge of two dicts - - Recursively merges `merge_dict` into `into_dict`: - * For keys where both `merge_dict` and `into_dict` have a dict value, the values - are recursively merged - * For all other keys, the values in `into_dict` (if any) are overwritten with - the value from `merge_dict`. 
- - Args: - merge_dict: dict to merge - into_dict: target dict to be modified - """ - for k, v in merge_dict.items(): - if k not in into_dict: - into_dict[k] = v - continue - - current_val = into_dict[k] - - if isinstance(v, dict) and isinstance(current_val, dict): - _dict_merge(v, current_val) - continue - - # otherwise we just overwrite - into_dict[k] = v - - -class SAML2Config(Config): - section = "saml2" - - def read_config(self, config: JsonDict, **kwargs: Any) -> None: - self.saml2_enabled = False - - saml2_config = config.get("saml2_config") - - if not saml2_config or not saml2_config.get("enabled", True): - return - - if not saml2_config.get("sp_config") and not saml2_config.get("config_path"): - return - - check_requirements("saml2") - - self.saml2_enabled = True - - attribute_requirements = saml2_config.get("attribute_requirements") or [] - self.attribute_requirements = _parse_attribute_requirements_def( - attribute_requirements - ) - - self.saml2_grandfathered_mxid_source_attribute = saml2_config.get( - "grandfathered_mxid_source_attribute", "uid" - ) - - # refers to a SAML IdP entity ID - self.saml2_idp_entityid = saml2_config.get("idp_entityid", None) - - # IdP properties for Matrix clients - self.idp_name = saml2_config.get("idp_name", "SAML") - self.idp_icon = saml2_config.get("idp_icon") - self.idp_brand = saml2_config.get("idp_brand") - - # user_mapping_provider may be None if the key is present but has no value - ump_dict = saml2_config.get("user_mapping_provider") or {} - - # Use the default user mapping provider if not set - ump_dict.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) - if ump_dict.get("module") == LEGACY_USER_MAPPING_PROVIDER: - ump_dict["module"] = DEFAULT_USER_MAPPING_PROVIDER - - # Ensure a config is present - ump_dict["config"] = ump_dict.get("config") or {} - - if ump_dict["module"] == DEFAULT_USER_MAPPING_PROVIDER: - # Load deprecated options for use by the default module - old_mxid_source_attribute = 
saml2_config.get("mxid_source_attribute") - if old_mxid_source_attribute: - logger.warning( - "The config option saml2_config.mxid_source_attribute is deprecated. " - "Please use saml2_config.user_mapping_provider.config" - ".mxid_source_attribute instead." - ) - ump_dict["config"]["mxid_source_attribute"] = old_mxid_source_attribute - - old_mxid_mapping = saml2_config.get("mxid_mapping") - if old_mxid_mapping: - logger.warning( - "The config option saml2_config.mxid_mapping is deprecated. Please " - "use saml2_config.user_mapping_provider.config.mxid_mapping instead." - ) - ump_dict["config"]["mxid_mapping"] = old_mxid_mapping - - # Retrieve an instance of the module's class - # Pass the config dictionary to the module for processing - ( - self.saml2_user_mapping_provider_class, - self.saml2_user_mapping_provider_config, - ) = load_module(ump_dict, ("saml2_config", "user_mapping_provider")) - - # Ensure loaded user mapping module has defined all necessary methods - # Note parse_config() is already checked during the call to load_module - required_methods = [ - "get_saml_attributes", - "saml_response_to_user_attributes", - "get_remote_user_id", - ] - missing_methods = [ - method - for method in required_methods - if not hasattr(self.saml2_user_mapping_provider_class, method) - ] - if missing_methods: - raise ConfigError( - "Class specified by saml2_config." 
- "user_mapping_provider.module is missing required " - "methods: %s" % (", ".join(missing_methods),) - ) - - # Get the desired saml auth response attributes from the module - saml2_config_dict = self._default_saml_config_dict( - *self.saml2_user_mapping_provider_class.get_saml_attributes( - self.saml2_user_mapping_provider_config - ) - ) - _dict_merge( - merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict - ) - - config_path = saml2_config.get("config_path", None) - if config_path is not None: - mod = load_python_module(config_path) - config_dict_from_file = getattr(mod, "CONFIG", None) - if config_dict_from_file is None: - raise ConfigError( - "Config path specified by saml2_config.config_path does not " - "have a CONFIG property." - ) - _dict_merge(merge_dict=config_dict_from_file, into_dict=saml2_config_dict) - - import saml2.config - - self.saml2_sp_config = saml2.config.SPConfig() - self.saml2_sp_config.load(saml2_config_dict) - - # session lifetime: in milliseconds - self.saml2_session_lifetime = self.parse_duration( - saml2_config.get("saml_session_lifetime", "15m") - ) - - def _default_saml_config_dict( - self, required_attributes: Set[str], optional_attributes: Set[str] - ) -> JsonDict: - """Generate a configuration dictionary with required and optional attributes that - will be needed to process new user registration - - Args: - required_attributes: SAML auth response attributes that are - necessary to function - optional_attributes: SAML auth response attributes that can be used to add - additional information to Synapse user accounts, but are not required - - Returns: - A SAML configuration dictionary - """ - import saml2 - - if self.saml2_grandfathered_mxid_source_attribute: - optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute) - optional_attributes -= required_attributes - - public_baseurl = self.root.server.public_baseurl - metadata_url = public_baseurl + "_synapse/client/saml2/metadata.xml" - response_url = 
public_baseurl + "_synapse/client/saml2/authn_response" - return { - "entityid": metadata_url, - "service": { - "sp": { - "endpoints": { - "assertion_consumer_service": [ - (response_url, saml2.BINDING_HTTP_POST) - ] - }, - "required_attributes": list(required_attributes), - "optional_attributes": list(optional_attributes), - # "name_id_format": saml2.saml.NAMEID_FORMAT_PERSISTENT, - } - }, - } - - -ATTRIBUTE_REQUIREMENTS_SCHEMA = { - "type": "array", - "items": SsoAttributeRequirement.JSON_SCHEMA, -} - - -def _parse_attribute_requirements_def( - attribute_requirements: Any, -) -> List[SsoAttributeRequirement]: - validate_config( - ATTRIBUTE_REQUIREMENTS_SCHEMA, - attribute_requirements, - config_path=("saml2_config", "attribute_requirements"), - ) - return [SsoAttributeRequirement(**x) for x in attribute_requirements] diff --git a/synapse/config/server.py b/synapse/config/server.py
index fd52c0475c..0844475b15 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py
@@ -2,7 +2,7 @@ # This file is licensed under the Affero General Public License (AGPL) version 3. # # Copyright 2014-2021 The Matrix.org Foundation C.I.C. -# Copyright (C) 2023 New Vector, Ltd +# Copyright (C) 2023-2024 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as @@ -43,12 +43,6 @@ from ._util import validate_config logger = logging.Logger(__name__) -DIRECT_TCP_ERROR = """ -Using direct TCP replication for workers is no longer supported. - -Please see https://element-hq.github.io/synapse/latest/upgrade.html#direct-tcp-replication-is-no-longer-supported-migrate-to-redis -""" - # by default, we attempt to listen on both '::' *and* '0.0.0.0' because some OSes # (Windows, macOS, other BSD/Linux where net.ipv6.bindv6only is set) will only listen # on IPv6 when '::' is set. @@ -166,13 +160,6 @@ ROOM_COMPLEXITY_TOO_GREAT = ( "to join this room." ) -METRICS_PORT_WARNING = """\ -The metrics_port configuration option is deprecated in Synapse 0.31 in favour of -a listener. Please see -https://element-hq.github.io/synapse/latest/metrics-howto.html -on how to configure the new listener. 
---------------------------------------------------------------------------------""" - KNOWN_LISTENER_TYPES = { "http", @@ -215,9 +202,6 @@ class HttpListenerConfig: additional_resources: Dict[str, dict] = attr.Factory(dict) tag: Optional[str] = None request_id_header: Optional[str] = None - # If true, the listener will return CORS response headers compatible with MSC3886: - # https://github.com/matrix-org/matrix-spec-proposals/pull/3886 - experimental_cors_msc3886: bool = False @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -335,8 +319,14 @@ class ServerConfig(Config): logger.info("Using default public_baseurl %s", public_baseurl) else: self.serve_client_wellknown = True + # Ensure that public_baseurl ends with a trailing slash if public_baseurl[-1] != "/": public_baseurl += "/" + + # Scrutinize user-provided config + if not isinstance(public_baseurl, str): + raise ConfigError("Must be a string", ("public_baseurl",)) + self.public_baseurl = public_baseurl # check that public_baseurl is valid @@ -367,11 +357,6 @@ class ServerConfig(Config): "m.homeserver is not supported in extra_well_known_content, " "use public_baseurl in base config instead." ) - if "m.identity_server" in self.extra_well_known_client_content: - raise ConfigError( - "m.identity_server is not supported in extra_well_known_content, " - "use default_identity_server in base config instead." - ) # Whether to enable user presence. 
presence_config = config.get("presence") or {} @@ -479,10 +464,6 @@ class ServerConfig(Config): self.max_mau_value = config.get("max_mau_value", 0) self.mau_stats_only = config.get("mau_stats_only", False) - self.mau_limits_reserved_threepids = config.get( - "mau_limit_reserved_threepids", [] - ) - self.mau_trial_days = config.get("mau_trial_days", 0) self.mau_appservice_trial_days = config.get("mau_appservice_trial_days", {}) self.mau_limit_alerting = config.get("mau_limit_alerting", True) @@ -700,21 +681,6 @@ class ServerConfig(Config): pub_key=manhole_pub_key, ) - metrics_port = config.get("metrics_port") - if metrics_port: - logger.warning(METRICS_PORT_WARNING) - - self.listeners.append( - TCPListenerConfig( - port=metrics_port, - bind_addresses=[config.get("metrics_bind_host", "127.0.0.1")], - type="http", - http_options=HttpListenerConfig( - resources=[HttpResourceConfig(names=["metrics"])] - ), - ) - ) - self.cleanup_extremities_with_dummy_events = config.get( "cleanup_extremities_with_dummy_events", True ) @@ -724,18 +690,6 @@ class ServerConfig(Config): self.enable_ephemeral_messages = config.get("enable_ephemeral_messages", False) - # Inhibits the /requestToken endpoints from returning an error that might leak - # information about whether an e-mail address is in use or not on this - # homeserver, and instead return a 200 with a fake sid if this kind of error is - # met, without sending anything. - # This is a compromise between sending an email, which could be a spam vector, - # and letting the client know which email address is bound to an account and - # which one isn't. 
- self.request_token_inhibit_3pid_errors = config.get( - "request_token_inhibit_3pid_errors", - False, - ) - # Whitelist of domain names that given next_link parameters must have next_link_domain_whitelist: Optional[List[str]] = config.get( "next_link_domain_whitelist" @@ -780,6 +734,17 @@ class ServerConfig(Config): else: self.delete_stale_devices_after = None + # The maximum allowed delay duration for delayed events (MSC4140). + max_event_delay_duration = config.get("max_event_delay_duration") + if max_event_delay_duration is not None: + self.max_event_delay_ms: Optional[int] = self.parse_duration( + max_event_delay_duration + ) + if self.max_event_delay_ms <= 0: + raise ConfigError("max_event_delay_duration must be a positive value") + else: + self.max_event_delay_ms = None + def has_tls_listener(self) -> bool: return any(listener.is_tls() for listener in self.listeners) @@ -828,13 +793,10 @@ class ServerConfig(Config): ).lstrip() if not unsecure_listeners: - unsecure_http_bindings = ( - """- port: %(unsecure_port)s + unsecure_http_bindings = """- port: %(unsecure_port)s tls: false type: http - x_forwarded: true""" - % locals() - ) + x_forwarded: true""" % locals() if not open_private_ports: unsecure_http_bindings += ( @@ -853,16 +815,13 @@ class ServerConfig(Config): if not secure_listeners: secure_http_bindings = "" - return ( - """\ + return """\ server_name: "%(server_name)s" pid_file: %(pid_file)s listeners: %(secure_http_bindings)s %(unsecure_http_bindings)s - """ - % locals() - ) + """ % locals() def read_arguments(self, args: argparse.Namespace) -> None: if args.manhole is not None: @@ -915,24 +874,6 @@ class ServerConfig(Config): ) -def is_threepid_reserved( - reserved_threepids: List[JsonDict], threepid: JsonDict -) -> bool: - """Check the threepid against the reserved threepid config - Args: - reserved_threepids: List of reserved threepids - threepid: The threepid to test for - - Returns: - Is the threepid undertest reserved_user - """ - - for tp in 
reserved_threepids: - if threepid["medium"] == tp["medium"] and threepid["address"] == tp["address"]: - return True - return False - - def read_gc_thresholds( thresholds: Optional[List[Any]], ) -> Optional[Tuple[int, int, int]]: @@ -956,9 +897,6 @@ def parse_listener_def(num: int, listener: Any) -> ListenerConfig: raise ConfigError("Expected a dictionary", ("listeners", str(num))) listener_type = listener["type"] - # Raise a helpful error if direct TCP replication is still configured. - if listener_type == "replication": - raise ConfigError(DIRECT_TCP_ERROR, ("listeners", str(num), "type")) port = listener.get("port") socket_path = listener.get("path") @@ -999,7 +937,6 @@ def parse_listener_def(num: int, listener: Any) -> ListenerConfig: additional_resources=listener.get("additional_resources", {}), tag=listener.get("tag"), request_id_header=listener.get("request_id_header"), - experimental_cors_msc3886=listener.get("experimental_cors_msc3886", False), ) if socket_path: diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index d7a2187e7d..cf27a7ee13 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py
@@ -19,7 +19,7 @@ # # import logging -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import attr @@ -43,13 +43,18 @@ class SsoAttributeRequirement: """Object describing a single requirement for SSO attributes.""" attribute: str - # If a value is not given, than the attribute must simply exist. - value: Optional[str] + # If neither `value` nor `one_of` is given, the attribute must simply exist. + value: Optional[str] = None + one_of: Optional[List[str]] = None JSON_SCHEMA = { "type": "object", - "properties": {"attribute": {"type": "string"}, "value": {"type": "string"}}, - "required": ["attribute", "value"], + "properties": { + "attribute": {"type": "string"}, + "value": {"type": "string"}, + "one_of": {"type": "array", "items": {"type": "string"}}, + }, + "required": ["attribute"], } diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 51dc15eb61..a48d81fdc3 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py
@@ -108,8 +108,7 @@ class TlsConfig(Config): # Raise an error if this option has been specified without any # corresponding certificates. raise ConfigError( - "federation_custom_ca_list specified without " - "any certificate files" + "federation_custom_ca_list specified without any certificate files" ) certs = [] diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index c67796906f..fe4e2dc65c 100644 --- a/synapse/config/user_directory.py +++ b/synapse/config/user_directory.py
@@ -38,6 +38,9 @@ class UserDirectoryConfig(Config): self.user_directory_search_all_users = user_directory_config.get( "search_all_users", False ) + self.user_directory_exclude_remote_users = user_directory_config.get( + "exclude_remote_users", False + ) self.user_directory_search_prefer_local_users = user_directory_config.get( "prefer_local_users", False ) diff --git a/synapse/config/user_types.py b/synapse/config/user_types.py new file mode 100644
index 0000000000..2d9c9f7afb --- /dev/null +++ b/synapse/config/user_types.py
@@ -0,0 +1,44 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# + +from typing import Any, List, Optional + +from synapse.api.constants import UserTypes +from synapse.types import JsonDict + +from ._base import Config, ConfigError + + +class UserTypesConfig(Config): + section = "user_types" + + def read_config(self, config: JsonDict, **kwargs: Any) -> None: + user_types: JsonDict = config.get("user_types", {}) + + self.default_user_type: Optional[str] = user_types.get( + "default_user_type", None + ) + self.extra_user_types: List[str] = user_types.get("extra_user_types", []) + + all_user_types: List[str] = [] + all_user_types.extend(UserTypes.ALL_BUILTIN_USER_TYPES) + all_user_types.extend(self.extra_user_types) + + self.all_user_types = all_user_types + + if self.default_user_type is not None: + if self.default_user_type not in all_user_types: + raise ConfigError( + f"Default user type {self.default_user_type} is not in the list of all user types: {all_user_types}" + ) diff --git a/synapse/config/voip.py b/synapse/config/voip.py
index 6fe43a9e32..f33602d975 100644 --- a/synapse/config/voip.py +++ b/synapse/config/voip.py
@@ -23,15 +23,34 @@ from typing import Any from synapse.types import JsonDict -from ._base import Config +from ._base import Config, ConfigError, read_file + +CONFLICTING_SHARED_SECRET_OPTS_ERROR = """\ +You have configured both `turn_shared_secret` and `turn_shared_secret_path`. +These are mutually incompatible. +""" class VoipConfig(Config): section = "voip" - def read_config(self, config: JsonDict, **kwargs: Any) -> None: + def read_config( + self, config: JsonDict, allow_secrets_in_config: bool, **kwargs: Any + ) -> None: self.turn_uris = config.get("turn_uris", []) self.turn_shared_secret = config.get("turn_shared_secret") + if self.turn_shared_secret and not allow_secrets_in_config: + raise ConfigError( + "Config options that expect an in-line secret as value are disabled", + ("turn_shared_secret",), + ) + turn_shared_secret_path = config.get("turn_shared_secret_path") + if turn_shared_secret_path: + if self.turn_shared_secret: + raise ConfigError(CONFLICTING_SHARED_SECRET_OPTS_ERROR) + self.turn_shared_secret = read_file( + turn_shared_secret_path, ("turn_shared_secret_path",) + ).strip() self.turn_username = config.get("turn_username") self.turn_password = config.get("turn_password") self.turn_user_lifetime = self.parse_duration( diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 7ecf349e4a..1685468773 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py
@@ -22,26 +22,26 @@ import argparse import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import attr -from synapse._pydantic_compat import HAS_PYDANTIC_V2 - -if TYPE_CHECKING or HAS_PYDANTIC_V2: - from pydantic.v1 import BaseModel, Extra, StrictBool, StrictInt, StrictStr -else: - from pydantic import BaseModel, Extra, StrictBool, StrictInt, StrictStr - +from synapse._pydantic_compat import ( + BaseModel, + Extra, + StrictBool, + StrictInt, + StrictStr, +) from synapse.config._base import ( Config, ConfigError, RoutableShardedWorkerHandlingConfig, ShardedWorkerHandlingConfig, + read_file, ) from synapse.config._util import parse_and_validate_mapping from synapse.config.server import ( - DIRECT_TCP_ERROR, TCPListenerConfig, parse_listener_def, ) @@ -65,6 +65,11 @@ configuration under `main` inside the `instance_map`. See workers documentation `https://element-hq.github.io/synapse/latest/workers.html#worker-configuration` """ +CONFLICTING_WORKER_REPLICATION_SECRET_OPTS_ERROR = """\ +Conflicting options 'worker_replication_secret' and +'worker_replication_secret_path' are both defined in config file. 
+""" + # This allows for a handy knob when it's time to change from 'master' to # something with less 'history' MAIN_PROCESS_INSTANCE_NAME = "master" @@ -218,7 +223,9 @@ class WorkerConfig(Config): section = "worker" - def read_config(self, config: JsonDict, **kwargs: Any) -> None: + def read_config( + self, config: JsonDict, allow_secrets_in_config: bool, **kwargs: Any + ) -> None: self.worker_app = config.get("worker_app") # Canonicalise worker_app so that master always has None @@ -237,12 +244,24 @@ class WorkerConfig(Config): raise ConfigError("worker_log_config must be a string") self.worker_log_config = worker_log_config - # The port on the main synapse for TCP replication - if "worker_replication_port" in config: - raise ConfigError(DIRECT_TCP_ERROR, ("worker_replication_port",)) - # The shared secret used for authentication when connecting to the main synapse. - self.worker_replication_secret = config.get("worker_replication_secret", None) + worker_replication_secret = config.get("worker_replication_secret", None) + if worker_replication_secret and not allow_secrets_in_config: + raise ConfigError( + "Config options that expect an in-line secret as value are disabled", + ("worker_replication_secret",), + ) + worker_replication_secret_path = config.get( + "worker_replication_secret_path", None + ) + if worker_replication_secret_path: + if worker_replication_secret: + raise ConfigError(CONFLICTING_WORKER_REPLICATION_SECRET_OPTS_ERROR) + self.worker_replication_secret = read_file( + worker_replication_secret_path, ("worker_replication_secret_path",) + ).strip() + else: + self.worker_replication_secret = worker_replication_secret self.worker_name = config.get("worker_name", self.worker_app) self.instance_name = self.worker_name or MAIN_PROCESS_INSTANCE_NAME @@ -328,10 +347,11 @@ class WorkerConfig(Config): ) # type-ignore: the expression `Union[A, B]` is not a Type[Union[A, B]] currently - self.instance_map: Dict[ - str, InstanceLocationConfig - ] = 
parse_and_validate_mapping( - instance_map, InstanceLocationConfig # type: ignore[arg-type] + self.instance_map: Dict[str, InstanceLocationConfig] = ( + parse_and_validate_mapping( + instance_map, + InstanceLocationConfig, # type: ignore[arg-type] + ) ) # Map from type of streams to source, c.f. WriterLocations. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 8c301e077c..643d2d4e66 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py
@@ -589,7 +589,7 @@ class BaseV2KeyFetcher(KeyFetcher): % (server_name,) ) - for key_id, key_data in response_json["old_verify_keys"].items(): + for key_id, key_data in response_json.get("old_verify_keys", {}).items(): if is_signing_algorithm_supported(key_id): key_base64 = key_data["key"] key_bytes = decode_base64(key_base64) diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index f5abcde2db..5999c264dc 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py
@@ -32,6 +32,7 @@ from typing import ( Mapping, MutableMapping, Optional, + Protocol, Set, Tuple, Union, @@ -41,7 +42,6 @@ from typing import ( from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes from signedjson.sign import SignatureVerifyException, verify_signed_json -from typing_extensions import Protocol from unpaddedbase64 import decode_base64 from synapse.api.constants import ( @@ -388,6 +388,7 @@ LENIENT_EVENT_BYTE_LIMITS_ROOM_VERSIONS = { RoomVersions.V9, RoomVersions.V10, RoomVersions.MSC1767v10, + RoomVersions.MSC3757v10, } @@ -565,6 +566,7 @@ def _is_membership_change_allowed( logger.debug( "_is_membership_change_allowed: %s", { + "caller_membership": caller.membership if caller else None, "caller_in_room": caller_in_room, "caller_invited": caller_invited, "caller_knocked": caller_knocked, @@ -676,7 +678,8 @@ def _is_membership_change_allowed( and join_rule == JoinRules.KNOCK_RESTRICTED ) ): - if not caller_in_room and not caller_invited: + # You can only join the room if you are invited or are already in the room. 
+ if not (caller_in_room or caller_invited): raise AuthError(403, "You are not invited to this room.") else: # TODO (erikj): may_join list @@ -790,9 +793,10 @@ def get_send_level( def _can_send_event(event: "EventBase", auth_events: StateMap["EventBase"]) -> bool: + state_key = event.get_state_key() power_levels_event = get_power_level_event(auth_events) - send_level = get_send_level(event.type, event.get("state_key"), power_levels_event) + send_level = get_send_level(event.type, state_key, power_levels_event) user_level = get_user_power_level(event.user_id, auth_events) if user_level < send_level: @@ -803,11 +807,34 @@ def _can_send_event(event: "EventBase", auth_events: StateMap["EventBase"]) -> b errcode=Codes.INSUFFICIENT_POWER, ) - # Check state_key - if hasattr(event, "state_key"): - if event.state_key.startswith("@"): - if event.state_key != event.user_id: - raise AuthError(403, "You are not allowed to set others state") + if ( + state_key is not None + and state_key.startswith("@") + and state_key != event.user_id + ): + if event.room_version.msc3757_enabled: + try: + colon_idx = state_key.index(":", 1) + suffix_idx = state_key.find("_", colon_idx + 1) + state_key_user_id = ( + state_key[:suffix_idx] if suffix_idx != -1 else state_key + ) + if not UserID.is_valid(state_key_user_id): + raise ValueError + except ValueError: + raise SynapseError( + 400, + "State key neither equals a valid user ID, nor starts with one plus an underscore", + errcode=Codes.BAD_JSON, + ) + if ( + # sender is owner of the state key + state_key_user_id == event.user_id + # sender has higher PL than the owner of the state key + or user_level > get_user_power_level(state_key_user_id, auth_events) + ): + return True + raise AuthError(403, "You are not allowed to set others state") return True @@ -887,7 +914,8 @@ def _check_power_levels( raise SynapseError(400, f"{v!r} must be an integer.") if k in {"events", "notifications", "users"}: if not isinstance(v, collections.abc.Mapping) or 
not all( - type(v) is int for v in v.values() # noqa: E721 + type(v) is int + for v in v.values() # noqa: E721 ): raise SynapseError( 400, @@ -958,8 +986,7 @@ def _check_power_levels( if old_level == user_level: raise AuthError( 403, - "You don't have permission to remove ops level equal " - "to your own", + "You don't have permission to remove ops level equal to your own", ) # Check if the old and new levels are greater than the user level diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 2e56b671f0..a85e66d6bf 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py
@@ -22,7 +22,6 @@ import abc import collections.abc -import os from typing import ( TYPE_CHECKING, Any, @@ -30,6 +29,7 @@ from typing import ( Generic, Iterable, List, + Literal, Optional, Tuple, Type, @@ -39,30 +39,29 @@ from typing import ( ) import attr -from typing_extensions import Literal from unpaddedbase64 import encode_base64 -from synapse.api.constants import RelationTypes +from synapse.api.constants import EventTypes, RelationTypes from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions from synapse.synapse_rust.events import EventInternalMetadata from synapse.types import JsonDict, StrCollection from synapse.util.caches import intern_dict from synapse.util.frozenutils import freeze -from synapse.util.stringutils import strtobool if TYPE_CHECKING: from synapse.events.builder import EventBuilder -# Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents -# bugs where we accidentally share e.g. signature dicts. However, converting a -# dict to frozen_dicts is expensive. -# -# NOTE: This is overridden by the configuration by the Synapse worker apps, but -# for the sake of tests, it is set here while it cannot be configured on the -# homeserver object itself. -USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0")) +USE_FROZEN_DICTS = False +""" +Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents +bugs where we accidentally share e.g. signature dicts. However, converting a +dict to frozen_dicts is expensive. +NOTE: This is overridden by the configuration by the Synapse worker apps, but +for the sake of tests, it is set here because it cannot be configured on the +homeserver object itself. 
+""" T = TypeVar("T") @@ -325,12 +324,17 @@ class EventBase(metaclass=abc.ABCMeta): def __repr__(self) -> str: rejection = f"REJECTED={self.rejected_reason}, " if self.rejected_reason else "" + conditional_membership_string = "" + if self.get("type") == EventTypes.Member: + conditional_membership_string = f"membership={self.membership}, " + return ( f"<{self.__class__.__name__} " f"{rejection}" f"event_id={self.event_id}, " f"type={self.get('type')}, " f"state_key={self.get('state_key')}, " + f"{conditional_membership_string}" f"outlier={self.internal_metadata.is_outlier()}" ">" ) diff --git a/synapse/events/auto_accept_invites.py b/synapse/events/auto_accept_invites.py
index d88ec51d9d..4295107c47 100644 --- a/synapse/events/auto_accept_invites.py +++ b/synapse/events/auto_accept_invites.py
@@ -66,49 +66,66 @@ class InviteAutoAccepter: event: The incoming event. """ # Check if the event is an invite for a local user. - is_invite_for_local_user = ( - event.type == EventTypes.Member - and event.is_state() - and event.membership == Membership.INVITE - and self._api.is_mine(event.state_key) - ) + if ( + event.type != EventTypes.Member + or event.is_state() is False + or event.membership != Membership.INVITE + or self._api.is_mine(event.state_key) is False + ): + return # Only accept invites for direct messages if the configuration mandates it. is_direct_message = event.content.get("is_direct", False) - is_allowed_by_direct_message_rules = ( - not self._config.accept_invites_only_for_direct_messages - or is_direct_message is True - ) + if ( + self._config.accept_invites_only_for_direct_messages + and is_direct_message is False + ): + return # Only accept invites from remote users if the configuration mandates it. is_from_local_user = self._api.is_mine(event.sender) - is_allowed_by_local_user_rules = ( - not self._config.accept_invites_only_from_local_users - or is_from_local_user is True - ) - if ( - is_invite_for_local_user - and is_allowed_by_direct_message_rules - and is_allowed_by_local_user_rules + self._config.accept_invites_only_from_local_users + and is_from_local_user is False ): - # Make the user join the room. We run this as a background process to circumvent a race condition - # that occurs when responding to invites over federation (see https://github.com/matrix-org/synapse-auto-accept-invite/issues/12) - run_as_background_process( - "retry_make_join", - self._retry_make_join, - event.state_key, - event.state_key, - event.room_id, - "join", - bg_start_span=False, - ) + return - if is_direct_message: - # Mark this room as a direct message! - await self._mark_room_as_direct_message( - event.state_key, event.sender, event.room_id - ) + # Check the user is activated. 
+ recipient = await self._api.get_userinfo_by_id(event.state_key) + + # Ignore if the user doesn't exist. + if recipient is None: + return + + # Never accept invites for deactivated users. + if recipient.is_deactivated: + return + + # Never accept invites for suspended users. + if recipient.suspended: + return + + # Never accept invites for locked users. + if recipient.locked: + return + + # Make the user join the room. We run this as a background process to circumvent a race condition + # that occurs when responding to invites over federation (see https://github.com/matrix-org/synapse-auto-accept-invite/issues/12) + run_as_background_process( + "retry_make_join", + self._retry_make_join, + event.state_key, + event.state_key, + event.room_id, + "join", + bg_start_span=False, + ) + + if is_direct_message: + # Mark this room as a direct message! + await self._mark_room_as_direct_message( + event.state_key, event.sender, event.room_id + ) async def _mark_room_as_direct_message( self, user_id: str, dm_user_id: str, room_id: str diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 10ef01131b..76df083d69 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py
@@ -24,7 +24,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import attr from signedjson.types import SigningKey -from synapse.api.constants import MAX_DEPTH +from synapse.api.constants import MAX_DEPTH, EventTypes from synapse.api.room_versions import ( KNOWN_EVENT_FORMAT_VERSIONS, EventFormatVersions, @@ -109,6 +109,19 @@ class EventBuilder: def is_state(self) -> bool: return self._state_key is not None + def is_mine_id(self, user_id: str) -> bool: + """Determines whether a user ID or room alias originates from this homeserver. + + Returns: + `True` if the hostname part of the user ID or room alias matches this + homeserver. + `False` otherwise, or if the user ID or room alias is malformed. + """ + localpart_hostname = user_id.split(":", 1) + if len(localpart_hostname) < 2: + return False + return localpart_hostname[1] == self._hostname + async def build( self, prev_event_ids: List[str], @@ -142,6 +155,46 @@ class EventBuilder: self, state_ids ) + # Check for out-of-band membership that may have been exposed on `/sync` but + # the events have not been de-outliered yet so they won't be part of the + # room state yet. + # + # This helps in situations where a remote homeserver invites a local user to + # a room that we're already participating in; and we've persisted the invite + # as an out-of-band membership (outlier), but it hasn't been pushed to us as + # part of a `/send` transaction yet and de-outliered. This also helps for + # any of the other out-of-band membership transitions. + # + # As an optimization, we could check if the room state already includes a + # non-`leave` membership event, then we can assume the membership event has + # been de-outliered and we don't need to check for an out-of-band + # membership. But we don't have the necessary information from a + # `StateMap[str]` and we'll just have to take the hit of this extra lookup + # for any membership event for now. 
+ if self.type == EventTypes.Member and self.is_mine_id(self.state_key): + ( + _membership, + member_event_id, + ) = await self._store.get_local_current_membership_for_user_in_room( + user_id=self.state_key, + room_id=self.room_id, + ) + # There is no need to check if the membership is actually an + # out-of-band membership (`outlier`) as we would end up with the + # same result either way (adding the member event to the + # `auth_event_ids`). + if ( + member_event_id is not None + # We only need to be careful about duplicating the event in the + # `auth_event_ids` list (duplicate `type`/`state_key` is part of the + # authorization rules) + and member_event_id not in auth_event_ids + ): + auth_event_ids.append(member_event_id) + # Also make sure to point to the previous membership event that will + # allow this one to happen so the computed state works out. + prev_event_ids.append(member_event_id) + format_version = self.room_version.event_format # The types of auth/prev events changes between event versions. prev_events: Union[StrCollection, List[Tuple[str, Dict[str, str]]]] diff --git a/synapse/events/presence_router.py b/synapse/events/presence_router.py
index 9cb053cd8e..9713b141bc 100644 --- a/synapse/events/presence_router.py +++ b/synapse/events/presence_router.py
@@ -80,7 +80,7 @@ def load_legacy_presence_router(hs: "HomeServer") -> None: # All methods that the module provides should be async, but this wasn't enforced # in the old module system, so we wrap them if needed def async_wrapper( - f: Optional[Callable[P, R]] + f: Optional[Callable[P, R]], ) -> Optional[Callable[P, Awaitable[R]]]: # f might be None if the callback isn't implemented by the module. In this # case we don't want to register a callback at all so we return None. diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 6b70ea94d1..0bca4c188b 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py
@@ -248,7 +248,7 @@ class EventContext(UnpersistedEventContextBase): @tag_args async def get_current_state_ids( self, state_filter: Optional["StateFilter"] = None - ) -> Optional[StateMap[str]]: + ) -> StateMap[str]: """ Gets the room state map, including this event - ie, the state in ``state_group`` @@ -256,13 +256,12 @@ class EventContext(UnpersistedEventContextBase): not make it into the room state. This method will raise an exception if ``rejected`` is set. + It is also an error to access this for an outlier event. + Arg: state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules Returns: - Returns None if state_group is None, which happens when the associated - event is an outlier. - Maps a (type, state_key) to the event ID of the state event matching this tuple. """ @@ -300,7 +299,8 @@ class EventContext(UnpersistedEventContextBase): this tuple. """ - assert self.state_group_before_event is not None + if self.state_group_before_event is None: + return {} return await self._storage.state.get_state_ids_for_group( self.state_group_before_event, state_filter ) @@ -504,7 +504,7 @@ class UnpersistedEventContext(UnpersistedEventContextBase): def _encode_state_group_delta( - state_group_delta: Dict[Tuple[int, int], StateMap[str]] + state_group_delta: Dict[Tuple[int, int], StateMap[str]], ) -> List[Tuple[int, int, Optional[List[Tuple[str, str, str]]]]]: if not state_group_delta: return [] @@ -517,7 +517,7 @@ def _encode_state_group_delta( def _decode_state_group_delta( - input: List[Tuple[int, int, List[Tuple[str, str, str]]]] + input: List[Tuple[int, int, List[Tuple[str, str, str]]]], ) -> Dict[Tuple[int, int], StateMap[str]]: if not input: return {} @@ -544,7 +544,7 @@ def _encode_state_dict( def _decode_state_dict( - input: Optional[List[Tuple[str, str, str]]] + input: Optional[List[Tuple[str, str, str]]], ) -> Optional[StateMap[str]]: """Decodes a state dict encoded using `_encode_state_dict` above""" if input is None: diff 
--git a/synapse/events/utils.py b/synapse/events/utils.py
index 54f94add4d..eb18ba2db7 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py
@@ -40,6 +40,8 @@ import attr from canonicaljson import encode_canonical_json from synapse.api.constants import ( + CANONICALJSON_MAX_INT, + CANONICALJSON_MIN_INT, MAX_PDU_SIZE, EventContentFields, EventTypes, @@ -61,9 +63,6 @@ SPLIT_FIELD_REGEX = re.compile(r"\\*\.") # Find escaped characters, e.g. those with a \ in front of them. ESCAPE_SEQUENCE_PATTERN = re.compile(r"\\(.)") -CANONICALJSON_MAX_INT = (2**53) - 1 -CANONICALJSON_MIN_INT = -CANONICALJSON_MAX_INT - # Module API callback that allows adding fields to the unsigned section of # events that are sent to clients. diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 73b63b77f2..d1fb026cd6 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py
@@ -19,17 +19,11 @@ # # import collections.abc -from typing import TYPE_CHECKING, List, Type, Union, cast +from typing import List, Type, Union, cast import jsonschema -from synapse._pydantic_compat import HAS_PYDANTIC_V2 - -if TYPE_CHECKING or HAS_PYDANTIC_V2: - from pydantic.v1 import Field, StrictBool, StrictStr -else: - from pydantic import Field, StrictBool, StrictStr - +from synapse._pydantic_compat import Field, StrictBool, StrictStr from synapse.api.constants import ( MAX_ALIAS_LENGTH, EventContentFields, @@ -92,9 +86,7 @@ class EventValidator: # Depending on the room version, ensure the data is spec compliant JSON. if event.room_version.strict_canonicaljson: - # Note that only the client controlled portion of the event is - # checked, since we trust the portions of the event we created. - validate_canonicaljson(event.content) + validate_canonicaljson(event.get_pdu_json()) if event.type == EventTypes.Aliases: if "aliases" in event.content: diff --git a/synapse/federation/__init__.py b/synapse/federation/__init__.py
index a571eff590..61e28bff66 100644 --- a/synapse/federation/__init__.py +++ b/synapse/federation/__init__.py
@@ -19,5 +19,4 @@ # # -""" This package includes all the federation specific logic. -""" +"""This package includes all the federation specific logic.""" diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index b101a389ef..45593430e8 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py
@@ -20,7 +20,7 @@ # # import logging -from typing import TYPE_CHECKING, Awaitable, Callable, Optional +from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Sequence from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership from synapse.api.errors import Codes, SynapseError @@ -29,6 +29,8 @@ from synapse.crypto.event_signing import check_event_content_hash from synapse.crypto.keyring import Keyring from synapse.events import EventBase, make_event_from_dict from synapse.events.utils import prune_event, validate_canonicaljson +from synapse.federation.units import filter_pdus_for_valid_depth +from synapse.handlers.room_policy import RoomPolicyHandler from synapse.http.servlet import assert_params_in_dict from synapse.logging.opentracing import log_kv, trace from synapse.types import JsonDict, get_domain_from_id @@ -63,6 +65,24 @@ class FederationBase: self._clock = hs.get_clock() self._storage_controllers = hs.get_storage_controllers() + # We need to define this lazily otherwise we get a cyclic dependency. + # self._policy_handler = hs.get_room_policy_handler() + self._policy_handler: Optional[RoomPolicyHandler] = None + + def _lazily_get_policy_handler(self) -> RoomPolicyHandler: + """Lazily get the room policy handler. + + This is required to avoid an import cycle: RoomPolicyHandler requires a + FederationClient, which requires a FederationBase, which requires a + RoomPolicyHandler. + + Returns: + RoomPolicyHandler: The room policy handler. + """ + if self._policy_handler is None: + self._policy_handler = self.hs.get_room_policy_handler() + return self._policy_handler + @trace async def _check_sigs_and_hash( self, @@ -79,6 +99,10 @@ class FederationBase: Also runs the event through the spam checker; if it fails, redacts the event and flags it as soft-failed. + Also checks that the event is allowed by the policy server, if the room uses + a policy server. 
If the event is not allowed, the event is flagged as + soft-failed but not redacted. + Args: room_version: The room version of the PDU pdu: the event to be checked @@ -144,6 +168,17 @@ class FederationBase: ) return redacted_event + policy_allowed = await self._lazily_get_policy_handler().is_event_allowed(pdu) + if not policy_allowed: + logger.warning( + "Event not allowed by policy server, soft-failing %s", pdu.event_id + ) + pdu.internal_metadata.soft_failed = True + # Note: we don't redact the event so admins can inspect the event after the + # fact. Other processes may redact the event, but that won't be applied to + # the database copy of the event until the server's config requires it. + return pdu + spam_check = await self._spam_checker_module_callbacks.check_event_for_spam(pdu) if spam_check != self._spam_checker_module_callbacks.NOT_SPAM: @@ -267,6 +302,15 @@ def _is_invite_via_3pid(event: EventBase) -> bool: ) +def parse_events_from_pdu_json( + pdus_json: Sequence[JsonDict], room_version: RoomVersion +) -> List[EventBase]: + return [ + event_from_pdu_json(pdu_json, room_version) + for pdu_json in filter_pdus_for_valid_depth(pdus_json) + ] + + def event_from_pdu_json(pdu_json: JsonDict, room_version: RoomVersion) -> EventBase: """Construct an EventBase from an event json received over federation diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 7d80ff6998..7c485aa7e0 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py
@@ -68,12 +68,14 @@ from synapse.federation.federation_base import ( FederationBase, InvalidEventSignatureError, event_from_pdu_json, + parse_events_from_pdu_json, ) from synapse.federation.transport.client import SendJoinResponse from synapse.http.client import is_unknown_endpoint from synapse.http.types import QueryParams from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, tag_args, trace from synapse.types import JsonDict, StrCollection, UserID, get_domain_from_id +from synapse.types.handlers.policy_server import RECOMMENDATION_OK, RECOMMENDATION_SPAM from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination @@ -349,7 +351,7 @@ class FederationClient(FederationBase): room_version = await self.store.get_room_version(room_id) - pdus = [event_from_pdu_json(p, room_version) for p in transaction_data_pdus] + pdus = parse_events_from_pdu_json(transaction_data_pdus, room_version) # Check signatures and hash of pdus, removing any from the list that fail checks pdus[:] = await self._check_sigs_and_hash_for_pulled_events_and_fetch( @@ -393,9 +395,7 @@ class FederationClient(FederationBase): transaction_data, ) - pdu_list: List[EventBase] = [ - event_from_pdu_json(p, room_version) for p in transaction_data["pdus"] - ] + pdu_list = parse_events_from_pdu_json(transaction_data["pdus"], room_version) if pdu_list and pdu_list[0]: pdu = pdu_list[0] @@ -424,6 +424,62 @@ class FederationClient(FederationBase): @trace @tag_args + async def get_pdu_policy_recommendation( + self, destination: str, pdu: EventBase, timeout: Optional[int] = None + ) -> str: + """Requests that the destination server (typically a policy server) + check the event and return its recommendation on how to handle the + event. + + If the policy server could not be contacted or the policy server + returned an unknown recommendation, this returns an OK recommendation. 
+ This type fixing behaviour is done because the typical caller will be + in a critical call path and would generally interpret a `None` or similar + response as "weird value; don't care; move on without taking action". We + just frontload that logic here. + + + Args: + destination: The remote homeserver to ask (a policy server) + pdu: The event to check + timeout: How long to try (in ms) the destination for before + giving up. None indicates no timeout. + + Returns: + The policy recommendation, or RECOMMENDATION_OK if the policy server was + uncontactable or returned an unknown recommendation. + """ + + logger.debug( + "get_pdu_policy_recommendation for event_id=%s from %s", + pdu.event_id, + destination, + ) + + try: + res = await self.transport_layer.get_policy_recommendation_for_pdu( + destination, pdu, timeout=timeout + ) + recommendation = res.get("recommendation") + if not isinstance(recommendation, str): + raise InvalidResponseError("recommendation is not a string") + if recommendation not in (RECOMMENDATION_OK, RECOMMENDATION_SPAM): + logger.warning( + "get_pdu_policy_recommendation: unknown recommendation: %s", + recommendation, + ) + return RECOMMENDATION_OK + return recommendation + except Exception as e: + logger.warning( + "get_pdu_policy_recommendation: server %s responded with error, assuming OK recommendation: %s", + destination, + e, + ) + return RECOMMENDATION_OK + + @trace + @tag_args async def get_pdu( self, destinations: Collection[str], @@ -809,7 +865,7 @@ class FederationClient(FederationBase): room_version = await self.store.get_room_version(room_id) - auth_chain = [event_from_pdu_json(p, room_version) for p in res["auth_chain"]] + auth_chain = parse_events_from_pdu_json(res["auth_chain"], room_version) signed_auth = await self._check_sigs_and_hash_for_pulled_events_and_fetch( destination, auth_chain, room_version=room_version @@ -1529,9 +1585,7 @@ class FederationClient(FederationBase): room_version = await 
self.store.get_room_version(room_id) - events = [ - event_from_pdu_json(e, room_version) for e in content.get("events", []) - ] + events = parse_events_from_pdu_json(content.get("events", []), room_version) signed_events = await self._check_sigs_and_hash_for_pulled_events_and_fetch( destination, events, room_version=room_version diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 1932fa82a4..2f2c78babc 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py
@@ -66,7 +66,7 @@ from synapse.federation.federation_base import ( event_from_pdu_json, ) from synapse.federation.persistence import TransactionActions -from synapse.federation.units import Edu, Transaction +from synapse.federation.units import Edu, Transaction, serialize_and_filter_pdus from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( @@ -469,7 +469,12 @@ class FederationServer(FederationBase): logger.info("Ignoring PDU: %s", e) continue - event = event_from_pdu_json(p, room_version) + try: + event = event_from_pdu_json(p, room_version) + except SynapseError as e: + logger.info("Ignoring PDU for failing to deserialize: %s", e) + continue + pdus_by_room.setdefault(room_id, []).append(event) if event.origin_server_ts > newest_pdu_ts: @@ -636,8 +641,8 @@ class FederationServer(FederationBase): ) return { - "pdus": [pdu.get_pdu_json() for pdu in pdus], - "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], + "pdus": serialize_and_filter_pdus(pdus), + "auth_chain": serialize_and_filter_pdus(auth_chain), } async def on_pdu_request( @@ -696,6 +701,12 @@ class FederationServer(FederationBase): pdu = event_from_pdu_json(content, room_version) origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, pdu.room_id) + if await self._spam_checker_module_callbacks.should_drop_federated_event(pdu): + logger.info( + "Federated event contains spam, dropping %s", + pdu.event_id, + ) + raise SynapseError(403, Codes.FORBIDDEN) try: pdu = await self._check_sigs_and_hash(room_version, pdu) except InvalidEventSignatureError as e: @@ -761,8 +772,8 @@ class FederationServer(FederationBase): event_json = event.get_pdu_json(time_now) resp = { "event": event_json, - "state": [p.get_pdu_json(time_now) for p in state_events], - "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain_events], + "state": 
serialize_and_filter_pdus(state_events, time_now), + "auth_chain": serialize_and_filter_pdus(auth_chain_events, time_now), "members_omitted": caller_supports_partial_state, } @@ -1005,7 +1016,7 @@ class FederationServer(FederationBase): time_now = self._clock.time_msec() auth_pdus = await self.handler.on_event_auth(event_id) - res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]} + res = {"auth_chain": serialize_and_filter_pdus(auth_pdus, time_now)} return 200, res async def on_query_client_keys( @@ -1090,7 +1101,7 @@ class FederationServer(FederationBase): time_now = self._clock.time_msec() - return {"events": [ev.get_pdu_json(time_now) for ev in missing_events]} + return {"events": serialize_and_filter_pdus(missing_events, time_now)} async def on_openid_userinfo(self, token: str) -> Optional[str]: ts_now_ms = self._clock.time_msec() diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index 0bfde00315..8340b48503 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py
@@ -20,7 +20,7 @@ # # -""" This module contains all the persistence actions done by the federation +"""This module contains all the persistence actions done by the federation package. These actions are mostly only used by the :py:mod:`.replication` module. diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 1888480881..2eef7b707d 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py
@@ -139,14 +139,13 @@ from typing import ( Hashable, Iterable, List, + Literal, Optional, - Set, Tuple, ) import attr from prometheus_client import Counter -from typing_extensions import Literal from twisted.internet import defer @@ -170,7 +169,13 @@ from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, ) -from synapse.types import JsonDict, ReadReceipt, RoomStreamToken, StrCollection +from synapse.types import ( + JsonDict, + ReadReceipt, + RoomStreamToken, + StrCollection, + get_domain_from_id, +) from synapse.util import Clock from synapse.util.metrics import Measure from synapse.util.retryutils import filter_destinations_by_retry_limiter @@ -297,12 +302,10 @@ class _DestinationWakeupQueue: # being woken up. _MAX_TIME_IN_QUEUE = 30.0 - # The maximum duration in seconds between waking up consecutive destination - # queues. - _MAX_DELAY = 0.1 - sender: "FederationSender" = attr.ib() clock: Clock = attr.ib() + max_delay_s: int = attr.ib() + queue: "OrderedDict[str, Literal[None]]" = attr.ib(factory=OrderedDict) processing: bool = attr.ib(default=False) @@ -332,13 +335,15 @@ class _DestinationWakeupQueue: # We also add an upper bound to the delay, to gracefully handle the # case where the queue only has a few entries in it. current_sleep_seconds = min( - self._MAX_DELAY, self._MAX_TIME_IN_QUEUE / len(self.queue) + self.max_delay_s, self._MAX_TIME_IN_QUEUE / len(self.queue) ) while self.queue: destination, _ = self.queue.popitem(last=False) queue = self.sender._get_per_destination_queue(destination) + if queue is None: + continue if not queue._new_data_to_send: # The per destination queue has already been woken up. @@ -416,19 +421,14 @@ class FederationSender(AbstractFederationSender): self._is_processing = False self._last_poked_id = -1 - # map from room_id to a set of PerDestinationQueues which we believe are - # awaiting a call to flush_read_receipts_for_room. 
The presence of an entry - # here for a given room means that we are rate-limiting RR flushes to that room, - # and that there is a pending call to _flush_rrs_for_room in the system. - self._queues_awaiting_rr_flush_by_room: Dict[str, Set[PerDestinationQueue]] = {} + self._external_cache = hs.get_external_cache() - self._rr_txn_interval_per_room_ms = ( - 1000.0 - / hs.config.ratelimiting.federation_rr_transactions_per_room_per_second + rr_txn_interval_per_room_s = ( + 1.0 / hs.config.ratelimiting.federation_rr_transactions_per_room_per_second + ) + self._destination_wakeup_queue = _DestinationWakeupQueue( + self, self.clock, max_delay_s=rr_txn_interval_per_room_s ) - - self._external_cache = hs.get_external_cache() - self._destination_wakeup_queue = _DestinationWakeupQueue(self, self.clock) # Regularly wake up destinations that have outstanding PDUs to be caught up self.clock.looping_call_now( @@ -438,12 +438,23 @@ class FederationSender(AbstractFederationSender): self._wake_destinations_needing_catchup, ) - def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue: + def _get_per_destination_queue( + self, destination: str + ) -> Optional[PerDestinationQueue]: """Get or create a PerDestinationQueue for the given destination Args: destination: server_name of remote server + + Returns: + None if the destination is not allowed by the federation whitelist. + Otherwise a PerDestinationQueue for this destination. """ + if not self.hs.config.federation.is_domain_allowed_according_to_federation_whitelist( + destination + ): + return None + queue = self._per_destination_queues.get(destination) if not queue: queue = PerDestinationQueue(self.hs, self._transaction_manager, destination) @@ -720,6 +731,16 @@ class FederationSender(AbstractFederationSender): # track the fact that we have a PDU for these destinations, # to allow us to perform catch-up later on if the remote is unreachable # for a while. 
+ # Filter out any destinations not present in the federation_domain_whitelist, if + # the whitelist exists. These destinations should not be sent to so let's not + # waste time or space keeping track of events destined for them. + destinations = [ + d + for d in destinations + if self.hs.config.federation.is_domain_allowed_according_to_federation_whitelist( + d + ) + ] await self.store.store_destination_rooms_entries( destinations, pdu.room_id, @@ -734,7 +755,12 @@ class FederationSender(AbstractFederationSender): ) for destination in destinations: - self._get_per_destination_queue(destination).send_pdu(pdu) + queue = self._get_per_destination_queue(destination) + # We expect `queue` to not be None as we already filtered out + # non-whitelisted destinations above. + assert queue is not None + + queue.send_pdu(pdu) async def send_read_receipt(self, receipt: ReadReceipt) -> None: """Send a RR to any other servers in the room @@ -745,37 +771,48 @@ class FederationSender(AbstractFederationSender): # Some background on the rate-limiting going on here. # - # It turns out that if we attempt to send out RRs as soon as we get them from - # a client, then we end up trying to do several hundred Hz of federation - # transactions. (The number of transactions scales as O(N^2) on the size of a - # room, since in a large room we have both more RRs coming in, and more servers - # to send them to.) + # It turns out that if we attempt to send out RRs as soon as we get them + # from a client, then we end up trying to do several hundred Hz of + # federation transactions. (The number of transactions scales as O(N^2) + # on the size of a room, since in a large room we have both more RRs + # coming in, and more servers to send them to.) # - # This leads to a lot of CPU load, and we end up getting behind. The solution - # currently adopted is as follows: + # This leads to a lot of CPU load, and we end up getting behind. 
The + # solution currently adopted is to differentiate between receipts and + # destinations we should immediately send to, and those we can trickle + # the receipts to. # - # The first receipt in a given room is sent out immediately, at time T0. Any - # further receipts are, in theory, batched up for N seconds, where N is calculated - # based on the number of servers in the room to achieve a transaction frequency - # of around 50Hz. So, for example, if there were 100 servers in the room, then - # N would be 100 / 50Hz = 2 seconds. + # The current logic is to send receipts out immediately if: + # - the room is "small", i.e. there's only N servers to send receipts + # to, and so sending out the receipts immediately doesn't cause too + # much load; or + # - the receipt is for an event that happened recently, as users + # notice if receipts are delayed when they know other users are + # currently reading the room; or + # - the receipt is being sent to the server that sent the event, so + # that users see receipts for their own receipts quickly. # - # Then, after T+N, we flush out any receipts that have accumulated, and restart - # the timer to flush out more receipts at T+2N, etc. If no receipts accumulate, - # we stop the cycle and go back to the start. + # For destinations that we should delay sending the receipt to, we queue + # the receipts up to be sent in the next transaction, but don't trigger + # a new transaction to be sent. We then add the destination to the + # `DestinationWakeupQueue`, which will slowly iterate over each + # destination and trigger a new transaction to be sent. # - # However, in practice, it is often possible to flush out receipts earlier: in - # particular, if we are sending a transaction to a given server anyway (for - # example, because we have a PDU or a RR in another room to send), then we may - # as well send out all of the pending RRs for that server. 
So it may be that - # by the time we get to T+N, we don't actually have any RRs left to send out. - # Nevertheless we continue to buffer up RRs for the room in question until we - # reach the point that no RRs arrive between timer ticks. + # However, in practice, it is often possible to send out delayed + # receipts earlier: in particular, if we are sending a transaction to a + # given server anyway (for example, because we have a PDU or a RR in + # another room to send), then we may as well send out all of the pending + # RRs for that server. So it may be that by the time we get to waking up + # the destination, we don't actually have any RRs left to send out. # - # For even more background, see https://github.com/matrix-org/synapse/issues/4730. + # For even more background, see + # https://github.com/matrix-org/synapse/issues/4730. room_id = receipt.room_id + # Local read receipts always have 1 event ID. + event_id = receipt.event_ids[0] + # Work out which remote servers should be poked and poke them. domains_set = await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation( room_id @@ -797,49 +834,55 @@ class FederationSender(AbstractFederationSender): if not domains: return - queues_pending_flush = self._queues_awaiting_rr_flush_by_room.get(room_id) + # We now split which domains we want to wake up immediately vs which we + # want to delay waking up. + immediate_domains: StrCollection + delay_domains: StrCollection - # if there is no flush yet scheduled, we will send out these receipts with - # immediate flushes, and schedule the next flush for this room. 
- if queues_pending_flush is not None: - logger.debug("Queuing receipt for: %r", domains) + if len(domains) < 10: + # For "small" rooms send to all domains immediately + immediate_domains = domains + delay_domains = () else: - logger.debug("Sending receipt to: %r", domains) - self._schedule_rr_flush_for_room(room_id, len(domains)) + metadata = await self.store.get_metadata_for_event( + receipt.room_id, event_id + ) + assert metadata is not None - for domain in domains: - queue = self._get_per_destination_queue(domain) - queue.queue_read_receipt(receipt) + sender_domain = get_domain_from_id(metadata.sender) - # if there is already a RR flush pending for this room, then make sure this - # destination is registered for the flush - if queues_pending_flush is not None: - queues_pending_flush.add(queue) + if self.clock.time_msec() - metadata.received_ts < 60_000: + # We always send receipts for recent messages immediately + immediate_domains = domains + delay_domains = () else: - queue.flush_read_receipts_for_room(room_id) - - def _schedule_rr_flush_for_room(self, room_id: str, n_domains: int) -> None: - # that is going to cause approximately len(domains) transactions, so now back - # off for that multiplied by RR_TXN_INTERVAL_PER_ROOM - backoff_ms = self._rr_txn_interval_per_room_ms * n_domains - - logger.debug("Scheduling RR flush in %s in %d ms", room_id, backoff_ms) - self.clock.call_later(backoff_ms, self._flush_rrs_for_room, room_id) - self._queues_awaiting_rr_flush_by_room[room_id] = set() - - def _flush_rrs_for_room(self, room_id: str) -> None: - queues = self._queues_awaiting_rr_flush_by_room.pop(room_id) - logger.debug("Flushing RRs in %s to %s", room_id, queues) - - if not queues: - # no more RRs arrived for this room; we are done. - return + # Otherwise, we delay waking up all destinations except for the + # sender's domain. 
+ immediate_domains = [] + delay_domains = [] + for domain in domains: + if domain == sender_domain: + immediate_domains.append(domain) + else: + delay_domains.append(domain) + + for domain in immediate_domains: + # Add to destination queue and wake the destination up + queue = self._get_per_destination_queue(domain) + if queue is None: + continue + queue.queue_read_receipt(receipt) + queue.attempt_new_transaction() - # schedule the next flush - self._schedule_rr_flush_for_room(room_id, len(queues)) + for domain in delay_domains: + # Add to destination queue... + queue = self._get_per_destination_queue(domain) + if queue is None: + continue + queue.queue_read_receipt(receipt) - for queue in queues: - queue.flush_read_receipts_for_room(room_id) + # ... and schedule the destination to be woken up. + self._destination_wakeup_queue.add_to_queue(domain) async def send_presence_to_destinations( self, states: Iterable[UserPresenceState], destinations: Iterable[str] @@ -871,9 +914,10 @@ class FederationSender(AbstractFederationSender): if self.is_mine_server_name(destination): continue - self._get_per_destination_queue(destination).send_presence( - states, start_loop=False - ) + queue = self._get_per_destination_queue(destination) + if queue is None: + continue + queue.send_presence(states, start_loop=False) self._destination_wakeup_queue.add_to_queue(destination) @@ -923,6 +967,8 @@ class FederationSender(AbstractFederationSender): return queue = self._get_per_destination_queue(edu.destination) + if queue is None: + return if key: queue.send_keyed_edu(edu, key) else: @@ -947,9 +993,15 @@ class FederationSender(AbstractFederationSender): for destination in destinations: if immediate: - self._get_per_destination_queue(destination).attempt_new_transaction() + queue = self._get_per_destination_queue(destination) + if queue is None: + continue + queue.attempt_new_transaction() else: - self._get_per_destination_queue(destination).mark_new_data() + queue = 
self._get_per_destination_queue(destination) + if queue is None: + continue + queue.mark_new_data() self._destination_wakeup_queue.add_to_queue(destination) def wake_destination(self, destination: str) -> None: @@ -968,7 +1020,9 @@ class FederationSender(AbstractFederationSender): ): return - self._get_per_destination_queue(destination).attempt_new_transaction() + queue = self._get_per_destination_queue(destination) + if queue is not None: + queue.attempt_new_transaction() @staticmethod def get_current_token() -> int: @@ -1013,6 +1067,9 @@ class FederationSender(AbstractFederationSender): d for d in destinations_to_wake if self._federation_shard_config.should_handle(self._instance_name, d) + and self.hs.config.federation.is_domain_allowed_according_to_federation_whitelist( + d + ) ] for destination in destinations_to_wake: diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index d097e65ea7..b3f65e8237 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py
@@ -156,7 +156,6 @@ class PerDestinationQueue: # Each receipt can only have a single receipt per # (room ID, receipt type, user ID, thread ID) tuple. self._pending_receipt_edus: List[Dict[str, Dict[str, Dict[str, dict]]]] = [] - self._rrs_pending_flush = False # stream_id of last successfully sent to-device message. # NB: may be a long or an int. @@ -258,15 +257,7 @@ class PerDestinationQueue: } ) - def flush_read_receipts_for_room(self, room_id: str) -> None: - # If there are any pending receipts for this room then force-flush them - # in a new transaction. - for edu in self._pending_receipt_edus: - if room_id in edu: - self._rrs_pending_flush = True - self.attempt_new_transaction() - # No use in checking remaining EDUs if the room was found. - break + self.mark_new_data() def send_keyed_edu(self, edu: Edu, key: Hashable) -> None: self._pending_edus_keyed[(edu.edu_type, key)] = edu @@ -603,12 +594,9 @@ class PerDestinationQueue: self._destination, last_successful_stream_ordering ) - def _get_receipt_edus(self, force_flush: bool, limit: int) -> Iterable[Edu]: + def _get_receipt_edus(self, limit: int) -> Iterable[Edu]: if not self._pending_receipt_edus: return - if not force_flush and not self._rrs_pending_flush: - # not yet time for this lot - return # Send at most limit EDUs for receipts. for content in self._pending_receipt_edus[:limit]: @@ -747,7 +735,7 @@ class _TransactionQueueManager: ) # Add read receipt EDUs. 
- pending_edus.extend(self.queue._get_receipt_edus(force_flush=False, limit=5)) + pending_edus.extend(self.queue._get_receipt_edus(limit=5)) edu_limit = MAX_EDUS_PER_TRANSACTION - len(pending_edus) # Next, prioritize to-device messages so that existing encryption channels @@ -795,13 +783,6 @@ class _TransactionQueueManager: if not self._pdus and not pending_edus: return [], [] - # if we've decided to send a transaction anyway, and we have room, we - # may as well send any pending RRs - if edu_limit: - pending_edus.extend( - self.queue._get_receipt_edus(force_flush=True, limit=edu_limit) - ) - if self._pdus: self._last_stream_ordering = self._pdus[ -1 diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 206e91ed14..62bf96ce91 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py
@@ -143,6 +143,33 @@ class TransportLayerClient: destination, path=path, timeout=timeout, try_trailing_slash_on_400=True ) + async def get_policy_recommendation_for_pdu( + self, destination: str, event: EventBase, timeout: Optional[int] = None + ) -> JsonDict: + """Requests the policy recommendation for the given pdu from the given policy server. + + Args: + destination: The host name of the remote homeserver checking the event. + event: The event to check. + timeout: How long to try (in ms) the destination for before giving up. + None indicates no timeout. + + Returns: + The full recommendation object from the remote server. + """ + logger.debug( + "get_policy_recommendation_for_pdu dest=%s, event_id=%s", + destination, + event.event_id, + ) + return await self.client.post_json( + destination=destination, + path=f"/_matrix/policy/unstable/org.matrix.msc4284/event/{event.event_id}/check", + data=event.get_pdu_json(), + ignore_backoff=True, + timeout=timeout, + ) + async def backfill( self, destination: str, room_id: str, event_tuples: Collection[str], limit: int ) -> Optional[Union[JsonDict, list]]: diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py
index 43102567db..174d02ab6b 100644 --- a/synapse/federation/transport/server/__init__.py +++ b/synapse/federation/transport/server/__init__.py
@@ -20,9 +20,7 @@ # # import logging -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Type - -from typing_extensions import Literal +from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Tuple, Type from synapse.api.errors import FederationDeniedError, SynapseError from synapse.federation.transport.server._base import ( diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py
index 9094201da0..cba309635b 100644 --- a/synapse/federation/transport/server/_base.py +++ b/synapse/federation/transport/server/_base.py
@@ -113,7 +113,7 @@ class Authenticator: ): raise AuthenticationError( HTTPStatus.UNAUTHORIZED, - "Destination mismatch in auth header", + f"Destination mismatch in auth header, received: {destination!r}", Codes.UNAUTHORIZED, ) if ( diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index 20f87c885e..eb96ff27f9 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py
@@ -24,6 +24,7 @@ from typing import ( TYPE_CHECKING, Dict, List, + Literal, Mapping, Optional, Sequence, @@ -32,8 +33,6 @@ from typing import ( Union, ) -from typing_extensions import Literal - from synapse.api.constants import Direction, EduTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersions @@ -509,6 +508,9 @@ class FederationV2InviteServlet(BaseFederationServerServlet): event = content["event"] invite_room_state = content.get("invite_room_state", []) + if not isinstance(invite_room_state, list): + invite_room_state = [] + # Synapse expects invite_room_state to be in unsigned, as it is in v1 # API @@ -859,7 +861,6 @@ class FederationMediaThumbnailServlet(BaseFederationServerServlet): request: SynapseRequest, media_id: str, ) -> None: - width = parse_integer(request, "width", required=True) height = parse_integer(request, "height", required=True) method = parse_string(request, "method", "scale") diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index b2c8ba5887..3bb5f824b7 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py
@@ -19,15 +19,17 @@ # # -""" Defines the JSON structure of the protocol units used by the server to +"""Defines the JSON structure of the protocol units used by the server to server protocol. """ import logging -from typing import List, Optional +from typing import List, Optional, Sequence import attr +from synapse.api.constants import CANONICALJSON_MAX_INT, CANONICALJSON_MIN_INT +from synapse.events import EventBase from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -104,8 +106,28 @@ class Transaction: result = { "origin": self.origin, "origin_server_ts": self.origin_server_ts, - "pdus": self.pdus, + "pdus": filter_pdus_for_valid_depth(self.pdus), } if self.edus: result["edus"] = self.edus return result + + +def filter_pdus_for_valid_depth(pdus: Sequence[JsonDict]) -> List[JsonDict]: + filtered_pdus = [] + for pdu in pdus: + # Drop PDUs that have a depth that is outside of the range allowed + # by canonical json. + if ( + "depth" in pdu + and CANONICALJSON_MIN_INT <= pdu["depth"] <= CANONICALJSON_MAX_INT + ): + filtered_pdus.append(pdu) + + return filtered_pdus + + +def serialize_and_filter_pdus( + pdus: Sequence[EventBase], time_now: Optional[int] = None +) -> List[JsonDict]: + return filter_pdus_for_valid_depth([pdu.get_pdu_json(time_now) for pdu in pdus]) diff --git a/synapse/handlers/account.py b/synapse/handlers/account.py
index 89e944bc17..37cc3d3ff5 100644 --- a/synapse/handlers/account.py +++ b/synapse/handlers/account.py
@@ -118,10 +118,10 @@ class AccountHandler: } if self._use_account_validity_in_account_status: - status["org.matrix.expired"] = ( - await self._account_validity_handler.is_user_expired( - user_id.to_string() - ) + status[ + "org.matrix.expired" + ] = await self._account_validity_handler.is_user_expired( + user_id.to_string() ) return status diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 97a463d8d0..228132db48 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py
@@ -33,7 +33,7 @@ from synapse.replication.http.account_data import ( ReplicationRemoveUserAccountDataRestServlet, ) from synapse.streams import EventSource -from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID +from synapse.types import JsonDict, JsonMapping, StrCollection, StreamKeyType, UserID if TYPE_CHECKING: from synapse.server import HomeServer @@ -253,7 +253,7 @@ class AccountDataHandler: return response["max_stream_id"] async def add_tag_to_room( - self, user_id: str, room_id: str, tag: str, content: JsonDict + self, user_id: str, room_id: str, tag: str, content: JsonMapping ) -> int: """Add a tag to a room for a user. diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 7004d95a0f..e40ca3e73f 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py
@@ -18,8 +18,6 @@ # # -import email.mime.multipart -import email.utils import logging from typing import TYPE_CHECKING, List, Optional, Tuple @@ -40,18 +38,13 @@ class AccountValidityHandler: self.hs = hs self.config = hs.config self.store = hs.get_datastores().main - self.send_email_handler = hs.get_send_email_handler() self.clock = hs.get_clock() - self._app_name = hs.config.email.email_app_name self._module_api_callbacks = hs.get_module_api_callbacks().account_validity self._account_validity_enabled = ( hs.config.account_validity.account_validity_enabled ) - self._account_validity_renew_by_email_enabled = ( - hs.config.account_validity.account_validity_renew_by_email_enabled - ) self._account_validity_period = None if self._account_validity_enabled: @@ -59,21 +52,6 @@ class AccountValidityHandler: hs.config.account_validity.account_validity_period ) - if ( - self._account_validity_enabled - and self._account_validity_renew_by_email_enabled - ): - # Don't do email-specific configuration if renewal by email is disabled. - self._template_html = hs.config.email.account_validity_template_html - self._template_text = hs.config.email.account_validity_template_text - self._renew_email_subject = ( - hs.config.account_validity.account_validity_renew_email_subject - ) - - # Check the renewal emails to send and send them every 30min. - if hs.config.worker.run_background_tasks: - self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000) - async def is_user_expired(self, user_id: str) -> bool: """Checks if a user has expired against third-party modules. 
@@ -120,125 +98,6 @@ class AccountValidityHandler: for callback in self._module_api_callbacks.on_user_login_callbacks: await callback(user_id, auth_provider_type, auth_provider_id) - @wrap_as_background_process("send_renewals") - async def _send_renewal_emails(self) -> None: - """Gets the list of users whose account is expiring in the amount of time - configured in the ``renew_at`` parameter from the ``account_validity`` - configuration, and sends renewal emails to all of these users as long as they - have an email 3PID attached to their account. - """ - expiring_users = await self.store.get_users_expiring_soon() - - if expiring_users: - for user_id, expiration_ts_ms in expiring_users: - await self._send_renewal_email( - user_id=user_id, expiration_ts=expiration_ts_ms - ) - - async def send_renewal_email_to_user(self, user_id: str) -> None: - """ - Send a renewal email for a specific user. - - Args: - user_id: The user ID to send a renewal email for. - - Raises: - SynapseError if the user is not set to renew. - """ - # If a module supports sending a renewal email from here, do that, otherwise do - # the legacy dance. - if self._module_api_callbacks.on_legacy_send_mail_callback is not None: - await self._module_api_callbacks.on_legacy_send_mail_callback(user_id) - return - - if not self._account_validity_renew_by_email_enabled: - raise AuthError( - 403, "Account renewal via email is disabled on this server." - ) - - expiration_ts = await self.store.get_expiration_ts_for_user(user_id) - - # If this user isn't set to be expired, raise an error. - if expiration_ts is None: - raise SynapseError(400, "User has no expiration time: %s" % (user_id,)) - - await self._send_renewal_email(user_id, expiration_ts) - - async def _send_renewal_email(self, user_id: str, expiration_ts: int) -> None: - """Sends out a renewal email to every email address attached to the given user - with a unique link allowing them to renew their account. 
- - Args: - user_id: ID of the user to send email(s) to. - expiration_ts: Timestamp in milliseconds for the expiration date of - this user's account (used in the email templates). - """ - addresses = await self._get_email_addresses_for_user(user_id) - - # Stop right here if the user doesn't have at least one email address. - # In this case, they will have to ask their server admin to renew their - # account manually. - # We don't need to do a specific check to make sure the account isn't - # deactivated, as a deactivated account isn't supposed to have any - # email address attached to it. - if not addresses: - return - - try: - user_display_name = await self.store.get_profile_displayname( - UserID.from_string(user_id) - ) - if user_display_name is None: - user_display_name = user_id - except StoreError: - user_display_name = user_id - - renewal_token = await self._get_renewal_token(user_id) - url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % ( - self.hs.config.server.public_baseurl, - renewal_token, - ) - - template_vars = { - "display_name": user_display_name, - "expiration_ts": expiration_ts, - "url": url, - } - - html_text = self._template_html.render(**template_vars) - plain_text = self._template_text.render(**template_vars) - - for address in addresses: - raw_to = email.utils.parseaddr(address)[1] - - await self.send_email_handler.send_email( - email_address=raw_to, - subject=self._renew_email_subject, - app_name=self._app_name, - html=html_text, - text=plain_text, - ) - - await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True) - - async def _get_email_addresses_for_user(self, user_id: str) -> List[str]: - """Retrieve the list of email addresses attached to a user's account. - - Args: - user_id: ID of the user to lookup email addresses for. - - Returns: - Email addresses for this account. 
- """ - threepids = await self.store.user_get_threepids(user_id) - - addresses = [] - for threepid in threepids: - if threepid.medium == "email": - addresses.append(threepid.address) - - return addresses - async def _get_renewal_token(self, user_id: str) -> str: """Generates a 32-byte long random string that will be inserted into the user's renewal email's unique link, then saves it into the database. diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index b44e862493..5467d129bd 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py
@@ -21,13 +21,34 @@ import abc import logging -from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence, Set +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Mapping, + Optional, + Sequence, + Set, + Tuple, +) import attr -from synapse.api.constants import Direction, Membership +from synapse.api.constants import Direction, EventTypes, Membership +from synapse.api.errors import SynapseError from synapse.events import EventBase -from synapse.types import JsonMapping, RoomStreamToken, StateMap, UserID, UserInfo +from synapse.types import ( + JsonMapping, + Requester, + RoomStreamToken, + ScheduledTask, + StateMap, + TaskStatus, + UserID, + UserInfo, + create_requester, +) from synapse.visibility import filter_events_for_client if TYPE_CHECKING: @@ -35,6 +56,8 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +REDACT_ALL_EVENTS_ACTION_NAME = "redact_all_events" + class AdminHandler: def __init__(self, hs: "HomeServer"): @@ -43,6 +66,22 @@ class AdminHandler: self._storage_controllers = hs.get_storage_controllers() self._state_storage_controller = self._storage_controllers.state self._msc3866_enabled = hs.config.experimental.msc3866.enabled + self.event_creation_handler = hs.get_event_creation_handler() + self._task_scheduler = hs.get_task_scheduler() + + self._task_scheduler.register_action( + self._redact_all_events, REDACT_ALL_EVENTS_ACTION_NAME + ) + + self.hs = hs + + async def get_redact_task(self, redact_id: str) -> Optional[ScheduledTask]: + """Get the current status of an active redaction process + + Args: + redact_id: redact_id returned by start_redact_events. 
+ """ + return await self._task_scheduler.get_task(redact_id) async def get_whois(self, user: UserID) -> JsonMapping: connections = [] @@ -85,6 +124,7 @@ class AdminHandler: "consent_ts": user_info.consent_ts, "user_type": user_info.user_type, "is_guest": user_info.is_guest, + "suspended": user_info.suspended, } if self._msc3866_enabled: @@ -93,7 +133,6 @@ class AdminHandler: # Add additional user metadata profile = await self._store.get_profileinfo(user) - threepids = await self._store.user_get_threepids(user.to_string()) external_ids = [ ({"auth_provider": auth_provider, "external_id": external_id}) for auth_provider, external_id in await self._store.get_external_ids_by_user( @@ -102,7 +141,6 @@ class AdminHandler: ] user_info_dict["displayname"] = profile.display_name user_info_dict["avatar_url"] = profile.avatar_url - user_info_dict["threepids"] = [attr.asdict(t) for t in threepids] user_info_dict["external_ids"] = external_ids user_info_dict["erased"] = await self._store.is_user_erased(user.to_string()) @@ -197,14 +235,16 @@ class AdminHandler: # events that we have and then filtering, this isn't the most # efficient method perhaps but it does guarantee we get everything. 
while True: - events, _ = ( - await self._store.paginate_room_events_by_topological_ordering( - room_id=room_id, - from_key=from_key, - to_key=to_key, - limit=100, - direction=Direction.FORWARDS, - ) + ( + events, + _, + _, + ) = await self._store.paginate_room_events_by_topological_ordering( + room_id=room_id, + from_key=from_key, + to_key=to_key, + limit=100, + direction=Direction.FORWARDS, ) if not events: break @@ -311,6 +351,155 @@ class AdminHandler: return writer.finished() + async def start_redact_events( + self, + user_id: str, + rooms: list, + requester: JsonMapping, + reason: Optional[str], + limit: Optional[int], + ) -> str: + """ + Start a task redacting the events of the given user in the given rooms + + Args: + user_id: the user ID of the user whose events should be redacted + rooms: the rooms in which to redact the user's events + requester: the user requesting the events + reason: reason for requesting the redaction, ie spam, etc + limit: limit on the number of events in each room to redact + + Returns: + a unique ID which can be used to query the status of the task + """ + active_tasks = await self._task_scheduler.get_tasks( + actions=[REDACT_ALL_EVENTS_ACTION_NAME], + resource_id=user_id, + statuses=[TaskStatus.ACTIVE], + ) + + if len(active_tasks) > 0: + raise SynapseError( + 400, "Redact already in progress for user %s" % (user_id,) + ) + + if not limit: + limit = 1000 + + redact_id = await self._task_scheduler.schedule_task( + REDACT_ALL_EVENTS_ACTION_NAME, + resource_id=user_id, + params={ + "rooms": rooms, + "requester": requester, + "user_id": user_id, + "reason": reason, + "limit": limit, + }, + ) + + logger.info( + "starting redact events with redact_id %s", + redact_id, + ) + + return redact_id + + async def _redact_all_events( + self, task: ScheduledTask + ) -> Tuple[TaskStatus, Optional[Mapping[str, Any]], Optional[str]]: + """ + Task to redact all of a users events in the given rooms, tracking which, if any, events + whose redaction 
failed + """ + + assert task.params is not None + rooms = task.params.get("rooms") + assert rooms is not None + + r = task.params.get("requester") + assert r is not None + admin = Requester.deserialize(self._store, r) + + user_id = task.params.get("user_id") + assert user_id is not None + + # puppet the user if they're ours, otherwise use admin to redact + requester = create_requester( + user_id if self.hs.is_mine_id(user_id) else admin.user.to_string(), + authenticated_entity=admin.user.to_string(), + ) + + reason = task.params.get("reason") + limit = task.params.get("limit") + assert limit is not None + + result: Mapping[str, Any] = ( + task.result if task.result else {"failed_redactions": {}} + ) + for room in rooms: + room_version = await self._store.get_room_version(room) + event_ids = await self._store.get_events_sent_by_user_in_room( + user_id, + room, + limit, + ["m.room.member", "m.room.message", "m.room.encrypted"], + ) + if not event_ids: + # nothing to redact in this room + continue + + events = await self._store.get_events_as_list(event_ids) + for event in events: + # we care about join events but not other membership events + if event.type == "m.room.member": + content = event.content + if content: + if content.get("membership") == Membership.JOIN: + pass + else: + continue + relations = await self._store.get_relations_for_event( + room, event.event_id, event, event_type=EventTypes.Redaction + ) + + # if we've already successfully redacted this event then skip processing it + if relations[0]: + continue + + event_dict = { + "type": EventTypes.Redaction, + "content": {"reason": reason} if reason else {}, + "room_id": room, + "sender": requester.user.to_string(), + } + if room_version.updated_redaction_rules: + event_dict["content"]["redacts"] = event.event_id + else: + event_dict["redacts"] = event.event_id + + try: + # set the prev event to the offending message to allow for redactions + # to be processed in the case where the user has been 
kicked/banned before + # redactions are requested + ( + redaction, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, + event_dict, + prev_event_ids=[event.event_id], + ratelimit=False, + ) + except Exception as ex: + logger.info( + f"Redaction of event {event.event_id} failed due to: {ex}" + ) + result["failed_redactions"][event.event_id] = str(ex) + await self._task_scheduler.update_task(task.id, result=result) + + return TaskStatus.COMPLETE, result, None + class ExfiltrationWriter(metaclass=abc.ABCMeta): """Interface used to specify how to write exported data.""" diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 4b33e1330d..b7d1033351 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py
@@ -896,10 +896,10 @@ class ApplicationServicesHandler: results = await make_deferred_yieldable( defer.DeferredList( [ - run_in_background( + run_in_background( # type: ignore[call-overload] self.appservice_api.claim_client_keys, # We know this must be an app service. - self.store.get_app_service_by_id(service_id), # type: ignore[arg-type] + self.store.get_app_service_by_id(service_id), service_query, ) for service_id, service_query in query_by_appservice.items() @@ -952,10 +952,10 @@ class ApplicationServicesHandler: results = await make_deferred_yieldable( defer.DeferredList( [ - run_in_background( + run_in_background( # type: ignore[call-overload] self.appservice_api.query_keys, # We know this must be an app service. - self.store.get_app_service_by_id(service_id), # type: ignore[arg-type] + self.store.get_app_service_by_id(service_id), service_query, ) for service_id, service_query in query_by_appservice.items() diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index a1fab99f6b..d37324cc46 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py
@@ -79,9 +79,7 @@ from synapse.storage.databases.main.registration import ( from synapse.types import JsonDict, Requester, UserID from synapse.util import stringutils as stringutils from synapse.util.async_helpers import delay_cancellation, maybe_awaitable -from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.stringutils import base62_encode -from synapse.util.threepids import canonicalise_email if TYPE_CHECKING: from synapse.module_api import ModuleApi @@ -153,42 +151,9 @@ def convert_client_dict_legacy_fields_to_identifier( return identifier -def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]: - """ - Convert a phone login identifier type to a generic threepid identifier. - - Args: - identifier: Login identifier dict of type 'm.id.phone' - - Returns: - An equivalent m.id.thirdparty identifier dict - """ - if "country" not in identifier or ( - # The specification requires a "phone" field, while Synapse used to require a "number" - # field. Accept both for backwards compatibility. - "phone" not in identifier - and "number" not in identifier - ): - raise SynapseError( - 400, "Invalid phone-type identifier", errcode=Codes.INVALID_PARAM - ) - - # Accept both "phone" and "number" as valid keys in m.id.phone - phone_number = identifier.get("phone", identifier["number"]) - - # Convert user-provided phone number to a consistent representation - msisdn = phone_number_to_msisdn(identifier["country"], phone_number) - - return { - "type": "m.id.thirdparty", - "medium": "msisdn", - "address": msisdn, - } - - @attr.s(slots=True, auto_attribs=True) class SsoLoginExtraAttributes: - """Data we track about SAML2 sessions""" + """Data we track about SAML2 sessions""" # Not other SSO types...? 
# time the session was created, in milliseconds creation_time: int @@ -1195,70 +1160,11 @@ class AuthHandler: # convert phone type identifiers to generic threepids if identifier_dict["type"] == "m.id.phone": - identifier_dict = login_id_phone_to_thirdparty(identifier_dict) + raise SynapseError(400, "Third party identifiers are not supported on this server.") # convert threepid identifiers to user IDs if identifier_dict["type"] == "m.id.thirdparty": - address = identifier_dict.get("address") - medium = identifier_dict.get("medium") - - if medium is None or address is None: - raise SynapseError(400, "Invalid thirdparty identifier") - - # For emails, canonicalise the address. - # We store all email addresses canonicalised in the DB. - # (See add_threepid in synapse/handlers/auth.py) - if medium == "email": - try: - address = canonicalise_email(address) - except ValueError as e: - raise SynapseError(400, str(e)) - - # We also apply account rate limiting using the 3PID as a key, as - # otherwise using 3PID bypasses the ratelimiting based on user ID. 
- if ratelimit: - await self._failed_login_attempts_ratelimiter.ratelimit( - None, (medium, address), update=False - ) - - # Check for login providers that support 3pid login types - if login_type == LoginType.PASSWORD: - # we've already checked that there is a (valid) password field - assert isinstance(password, str) - ( - canonical_user_id, - callback_3pid, - ) = await self.check_password_provider_3pid(medium, address, password) - if canonical_user_id: - # Authentication through password provider and 3pid succeeded - return canonical_user_id, callback_3pid - - # No password providers were able to handle this 3pid - # Check local store - user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( - medium, address - ) - if not user_id: - logger.warning( - "unknown 3pid identifier medium %s, address %r", medium, address - ) - # We mark that we've failed to log in here, as - # `check_password_provider_3pid` might have returned `None` due - # to an incorrect password, rather than the account not - # existing. - # - # If it returned None but the 3PID was bound then we won't hit - # this code path, which is fine as then the per-user ratelimit - # will kick in below. - if ratelimit: - await self._failed_login_attempts_ratelimiter.can_do_action( - None, (medium, address) - ) - raise LoginError( - 403, msg=INVALID_USERNAME_OR_PASSWORD, errcode=Codes.FORBIDDEN - ) - - identifier_dict = {"type": "m.id.user", "user": user_id} + raise SynapseError(400, "Third party identifiers are not supported on this server.") # by this point, the identifier should be an m.id.user: if it's anything # else, we haven't understood it. @@ -1548,83 +1454,6 @@ class AuthHandler: user_id, (token_id for _, token_id, _ in tokens_and_devices) ) - async def add_threepid( - self, user_id: str, medium: str, address: str, validated_at: int - ) -> None: - """ - Adds an association between a user's Matrix ID and a third-party ID (email, - phone number). 
- - Args: - user_id: The ID of the user to associate. - medium: The medium of the third-party ID (email, msisdn). - address: The address of the third-party ID (i.e. an email address). - validated_at: The timestamp in ms of when the validation that the user owns - this third-party ID occurred. - """ - # check if medium has a valid value - if medium not in ["email", "msisdn"]: - raise SynapseError( - code=400, - msg=("'%s' is not a valid value for 'medium'" % (medium,)), - errcode=Codes.INVALID_PARAM, - ) - - # 'Canonicalise' email addresses down to lower case. - # We've now moving towards the homeserver being the entity that - # is responsible for validating threepids used for resetting passwords - # on accounts, so in future Synapse will gain knowledge of specific - # types (mediums) of threepid. For now, we still use the existing - # infrastructure, but this is the start of synapse gaining knowledge - # of specific types of threepid (and fixes the fact that checking - # for the presence of an email address during password reset was - # case sensitive). - if medium == "email": - address = canonicalise_email(address) - - await self.store.user_add_threepid( - user_id, medium, address, validated_at, self.hs.get_clock().time_msec() - ) - - # Inform Synapse modules that a 3PID association has been created. - await self._third_party_rules.on_add_user_third_party_identifier( - user_id, medium, address - ) - - # Deprecated method for informing Synapse modules that a 3PID association - # has successfully been created. - await self._third_party_rules.on_threepid_bind(user_id, medium, address) - - async def delete_local_threepid( - self, user_id: str, medium: str, address: str - ) -> None: - """Deletes an association between a third-party ID and a user ID from the local - database. This method does not unbind the association from any identity servers. - - If `medium` is 'email' and a pusher is associated with this third-party ID, the - pusher will also be deleted. 
- - Args: - user_id: ID of user to remove the 3pid from. - medium: The medium of the 3pid being removed: "email" or "msisdn". - address: The 3pid address to remove. - """ - # 'Canonicalise' email addresses as per above - if medium == "email": - address = canonicalise_email(address) - - await self.store.user_delete_threepid(user_id, medium, address) - - # Inform Synapse modules that a 3PID association has been deleted. - await self._third_party_rules.on_remove_user_third_party_identifier( - user_id, medium, address - ) - - if medium == "email": - await self.store.delete_pusher_by_app_id_pushkey_user_id( - app_id="m.email", pushkey=address, user_id=user_id - ) - async def hash(self, password: str) -> str: """Computes a secure hash of password. diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py deleted file mode 100644
index cc3d641b7d..0000000000 --- a/synapse/handlers/cas.py +++ /dev/null
@@ -1,412 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright 2020 The Matrix.org Foundation C.I.C. -# Copyright (C) 2023 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -# Originally licensed under the Apache License, Version 2.0: -# <http://www.apache.org/licenses/LICENSE-2.0>. -# -# [This file includes modifications made by New Vector Limited] -# -# -import logging -import urllib.parse -from typing import TYPE_CHECKING, Dict, List, Optional -from xml.etree import ElementTree as ET - -import attr - -from twisted.web.client import PartialDownloadError - -from synapse.api.errors import HttpResponseException -from synapse.handlers.sso import MappingException, UserAttributes -from synapse.http.site import SynapseRequest -from synapse.types import UserID, map_username_to_mxid_localpart - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -class CasError(Exception): - """Used to catch errors when validating the CAS ticket.""" - - def __init__(self, error: str, error_description: Optional[str] = None): - self.error = error - self.error_description = error_description - - def __str__(self) -> str: - if self.error_description: - return f"{self.error}: {self.error_description}" - return self.error - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class CasResponse: - username: str - attributes: Dict[str, List[Optional[str]]] - - -class CasHandler: - """ - Utility class for to handle the response from a CAS SSO service. 
- - Args: - hs - """ - - def __init__(self, hs: "HomeServer"): - self.hs = hs - self._hostname = hs.hostname - self._store = hs.get_datastores().main - self._auth_handler = hs.get_auth_handler() - self._registration_handler = hs.get_registration_handler() - - self._cas_server_url = hs.config.cas.cas_server_url - self._cas_service_url = hs.config.cas.cas_service_url - self._cas_protocol_version = hs.config.cas.cas_protocol_version - self._cas_displayname_attribute = hs.config.cas.cas_displayname_attribute - self._cas_required_attributes = hs.config.cas.cas_required_attributes - self._cas_enable_registration = hs.config.cas.cas_enable_registration - self._cas_allow_numeric_ids = hs.config.cas.cas_allow_numeric_ids - self._cas_numeric_ids_prefix = hs.config.cas.cas_numeric_ids_prefix - - self._http_client = hs.get_proxied_http_client() - - # identifier for the external_ids table - self.idp_id = "cas" - - # user-facing name of this auth provider - self.idp_name = hs.config.cas.idp_name - - # MXC URI for icon for this auth provider - self.idp_icon = hs.config.cas.idp_icon - - # optional brand identifier for this auth provider - self.idp_brand = hs.config.cas.idp_brand - - self._sso_handler = hs.get_sso_handler() - - self._sso_handler.register_identity_provider(self) - - def _build_service_param(self, args: Dict[str, str]) -> str: - """ - Generates a value to use as the "service" parameter when redirecting or - querying the CAS service. - - Args: - args: Additional arguments to include in the final redirect URL. - - Returns: - The URL to use as a "service" parameter. - """ - return "%s?%s" % ( - self._cas_service_url, - urllib.parse.urlencode(args), - ) - - async def _validate_ticket( - self, ticket: str, service_args: Dict[str, str] - ) -> CasResponse: - """ - Validate a CAS ticket with the server, and return the parsed the response. - - Args: - ticket: The CAS ticket from the client. - service_args: Additional arguments to include in the service URL. 
- Should be the same as those passed to `handle_redirect_request`. - - Raises: - CasError: If there's an error parsing the CAS response. - - Returns: - The parsed CAS response. - """ - if self._cas_protocol_version == 3: - uri = self._cas_server_url + "/p3/proxyValidate" - else: - uri = self._cas_server_url + "/proxyValidate" - args = { - "ticket": ticket, - "service": self._build_service_param(service_args), - } - try: - body = await self._http_client.get_raw(uri, args) - except PartialDownloadError as pde: - # Twisted raises this error if the connection is closed, - # even if that's being used old-http style to signal end-of-data - # Assertion is for mypy's benefit. Error.response is Optional[bytes], - # but a PartialDownloadError should always have a non-None response. - assert pde.response is not None - body = pde.response - except HttpResponseException as e: - description = ( - 'Authorization server responded with a "{status}" error ' - "while exchanging the authorization code." - ).format(status=e.code) - raise CasError("server_error", description) from e - - return self._parse_cas_response(body) - - def _parse_cas_response(self, cas_response_body: bytes) -> CasResponse: - """ - Retrieve the user and other parameters from the CAS response. - - Args: - cas_response_body: The response from the CAS query. - - Raises: - CasError: If there's an error parsing the CAS response. - - Returns: - The parsed CAS response. - """ - - # Ensure the response is valid. - root = ET.fromstring(cas_response_body) - if not root.tag.endswith("serviceResponse"): - raise CasError( - "missing_service_response", - "root of CAS response is not serviceResponse", - ) - - success = root[0].tag.endswith("authenticationSuccess") - if not success: - raise CasError("unsucessful_response", "Unsuccessful CAS response") - - # Iterate through the nodes and pull out the user and any extra attributes. 
- user = None - attributes: Dict[str, List[Optional[str]]] = {} - for child in root[0]: - if child.tag.endswith("user"): - user = child.text - # if numeric user IDs are allowed and username is numeric then we add the prefix so Synapse can handle it - if self._cas_allow_numeric_ids and user is not None and user.isdigit(): - user = f"{self._cas_numeric_ids_prefix}{user}" - if child.tag.endswith("attributes"): - for attribute in child: - # ElementTree library expands the namespace in - # attribute tags to the full URL of the namespace. - # We don't care about namespace here and it will always - # be encased in curly braces, so we remove them. - tag = attribute.tag - if "}" in tag: - tag = tag.split("}")[1] - attributes.setdefault(tag, []).append(attribute.text) - - # Ensure a user was found. - if user is None: - raise CasError("no_user", "CAS response does not contain user") - - return CasResponse(user, attributes) - - async def handle_redirect_request( - self, - request: SynapseRequest, - client_redirect_url: Optional[bytes], - ui_auth_session_id: Optional[str] = None, - ) -> str: - """Generates a URL for the CAS server where the client should be redirected. - - Args: - request: the incoming HTTP request - client_redirect_url: the URL that we should redirect the - client to after login (or None for UI Auth). - ui_auth_session_id: The session ID of the ongoing UI Auth (or - None if this is a login). 
- - Returns: - URL to redirect to - """ - - if ui_auth_session_id: - service_args = {"session": ui_auth_session_id} - else: - assert client_redirect_url - service_args = {"redirectUrl": client_redirect_url.decode("utf8")} - - args = urllib.parse.urlencode( - {"service": self._build_service_param(service_args)} - ) - - return "%s/login?%s" % (self._cas_server_url, args) - - async def handle_ticket( - self, - request: SynapseRequest, - ticket: str, - client_redirect_url: Optional[str], - session: Optional[str], - ) -> None: - """ - Called once the user has successfully authenticated with the SSO. - Validates a CAS ticket sent by the client and completes the auth process. - - If the user interactive authentication session is provided, marks the - UI Auth session as complete, then returns an HTML page notifying the - user they are done. - - Otherwise, this registers the user if necessary, and then returns a - redirect (with a login token) to the client. - - Args: - request: the incoming request from the browser. We'll - respond to it with a redirect or an HTML page. - - ticket: The CAS ticket provided by the client. - - client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given. - This should be the same as the redirectUrl from the original `/login/sso/redirect` request. - - session: The session parameter from the `/cas/ticket` HTTP request, if given. - This should be the UI Auth session id. 
- """ - args = {} - if client_redirect_url: - args["redirectUrl"] = client_redirect_url - if session: - args["session"] = session - - try: - cas_response = await self._validate_ticket(ticket, args) - except CasError as e: - logger.exception("Could not validate ticket") - self._sso_handler.render_error(request, e.error, e.error_description, 401) - return - - await self._handle_cas_response( - request, cas_response, client_redirect_url, session - ) - - async def _handle_cas_response( - self, - request: SynapseRequest, - cas_response: CasResponse, - client_redirect_url: Optional[str], - session: Optional[str], - ) -> None: - """Handle a CAS response to a ticket request. - - Assumes that the response has been validated. Maps the user onto an MXID, - registering them if necessary, and returns a response to the browser. - - Args: - request: the incoming request from the browser. We'll respond to it with an - HTML page or a redirect - - cas_response: The parsed CAS response. - - client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given. - This should be the same as the redirectUrl from the original `/login/sso/redirect` request. - - session: The session parameter from the `/cas/ticket` HTTP request, if given. - This should be the UI Auth session id. - """ - - # first check if we're doing a UIA - if session: - return await self._sso_handler.complete_sso_ui_auth_request( - self.idp_id, - cas_response.username, - session, - request, - ) - - # otherwise, we're handling a login request. - - # Ensure that the attributes of the logged in user meet the required - # attributes. - if not self._sso_handler.check_required_attributes( - request, cas_response.attributes, self._cas_required_attributes - ): - return - - # Call the mapper to register/login the user - - # If this not a UI auth request than there must be a redirect URL. 
- assert client_redirect_url is not None - - try: - await self._complete_cas_login(cas_response, request, client_redirect_url) - except MappingException as e: - logger.exception("Could not map user") - self._sso_handler.render_error(request, "mapping_error", str(e)) - - async def _complete_cas_login( - self, - cas_response: CasResponse, - request: SynapseRequest, - client_redirect_url: str, - ) -> None: - """ - Given a CAS response, complete the login flow - - Retrieves the remote user ID, registers the user if necessary, and serves - a redirect back to the client with a login-token. - - Args: - cas_response: The parsed CAS response. - request: The request to respond to - client_redirect_url: The redirect URL passed in by the client. - - Raises: - MappingException if there was a problem mapping the response to a user. - RedirectException: some mapping providers may raise this if they need - to redirect to an interstitial page. - """ - # Note that CAS does not support a mapping provider, so the logic is hard-coded. - localpart = map_username_to_mxid_localpart(cas_response.username) - - async def cas_response_to_user_attributes(failures: int) -> UserAttributes: - """ - Map from CAS attributes to user attributes. - """ - # Due to the grandfathering logic matching any previously registered - # mxids it isn't expected for there to be any failures. - if failures: - raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs") - - # Arbitrarily use the first attribute found. - display_name = cas_response.attributes.get( - self._cas_displayname_attribute, [None] - )[0] - - return UserAttributes(localpart=localpart, display_name=display_name) - - async def grandfather_existing_users() -> Optional[str]: - # Since CAS did not always use the user_external_ids table, always - # to attempt to map to existing users. 
- user_id = UserID(localpart, self._hostname).to_string() - - logger.debug( - "Looking for existing account based on mapped %s", - user_id, - ) - - users = await self._store.get_users_by_id_case_insensitive(user_id) - if users: - registered_user_id = list(users.keys())[0] - logger.info("Grandfathering mapping to %s", registered_user_id) - return registered_user_id - - return None - - await self._sso_handler.complete_sso_login_request( - self.idp_id, - cas_response.username, - request, - client_redirect_url, - cas_response_to_user_attributes, - grandfather_existing_users, - registration_enabled=self._cas_enable_registration, - ) diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 12a7cace55..2c4991c6e5 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py
@@ -43,7 +43,6 @@ class DeactivateAccountHandler: self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() self._room_member_handler = hs.get_room_member_handler() - self._identity_handler = hs.get_identity_handler() self._profile_handler = hs.get_profile_handler() self.user_directory_handler = hs.get_user_directory_handler() self._server_name = hs.hostname @@ -82,7 +81,7 @@ class DeactivateAccountHandler: by_admin: Whether this change was made by an administrator. Returns: - True if identity server supports removing threepids, otherwise False. + True """ # This can only be called on the main process. @@ -96,40 +95,6 @@ class DeactivateAccountHandler: 403, "Deactivation of this user is forbidden", Codes.FORBIDDEN ) - # FIXME: Theoretically there is a race here wherein user resets - # password using threepid. - - # delete threepids first. We remove these from the IS so if this fails, - # leave the user still active so they can try again. - # Ideally we would prevent password resets and then do this in the - # background thread. - - # This will be set to false if the identity server doesn't support - # unbinding - identity_server_supports_unbinding = True - - # Attempt to unbind any known bound threepids to this account from identity - # server(s). - bound_threepids = await self.store.user_get_bound_threepids(user_id) - for medium, address in bound_threepids: - try: - result = await self._identity_handler.try_unbind_threepid( - user_id, medium, address, id_server - ) - except Exception: - # Do we want this to be a fatal error or should we carry on? - logger.exception("Failed to remove threepid from ID server") - raise SynapseError(400, "Failed to remove threepid from ID server") - - identity_server_supports_unbinding &= result - - # Remove any local threepid associations for this account. 
- local_threepids = await self.store.user_get_threepids(user_id) - for local_threepid in local_threepids: - await self._auth_handler.delete_local_threepid( - user_id, local_threepid.medium, local_threepid.address - ) - # delete any devices belonging to the user, which will also # delete corresponding access tokens. await self._device_handler.delete_all_devices_for_user(user_id) @@ -194,7 +159,7 @@ class DeactivateAccountHandler: by_admin, ) - return identity_server_supports_unbinding + return True async def _reject_pending_invites_and_knocks_for_user(self, user_id: str) -> None: """Reject pending invites and knocks addressed to a given user ID. diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py new file mode 100644
index 0000000000..cb2a34ff73 --- /dev/null +++ b/synapse/handlers/delayed_events.py
@@ -0,0 +1,545 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2024 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# + +import logging +from typing import TYPE_CHECKING, List, Optional, Set, Tuple + +from twisted.internet.interfaces import IDelayedCall + +from synapse.api.constants import EventTypes +from synapse.api.errors import ShadowBanError +from synapse.api.ratelimiting import Ratelimiter +from synapse.config.workers import MAIN_PROCESS_INSTANCE_NAME +from synapse.logging.opentracing import set_tag +from synapse.metrics import event_processing_positions +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.http.delayed_events import ( + ReplicationAddedDelayedEventRestServlet, +) +from synapse.storage.databases.main.delayed_events import ( + DelayedEventDetails, + DelayID, + EventType, + StateKey, + Timestamp, + UserLocalpart, +) +from synapse.storage.databases.main.state_deltas import StateDelta +from synapse.types import ( + JsonDict, + Requester, + RoomID, + UserID, + create_requester, +) +from synapse.util.events import generate_fake_event_id +from synapse.util.metrics import Measure + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class DelayedEventsHandler: + def __init__(self, hs: "HomeServer"): + self._store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() + self._config = hs.config + self._clock = hs.get_clock() + self._event_creation_handler = hs.get_event_creation_handler() + self._room_member_handler = 
hs.get_room_member_handler() + + self._request_ratelimiter = hs.get_request_ratelimiter() + + # Ratelimiter for management of existing delayed events, + # keyed by the sending user ID & device ID. + self._delayed_event_mgmt_ratelimiter = Ratelimiter( + store=self._store, + clock=self._clock, + cfg=self._config.ratelimiting.rc_delayed_event_mgmt, + ) + + self._next_delayed_event_call: Optional[IDelayedCall] = None + + # The current position in the current_state_delta stream + self._event_pos: Optional[int] = None + + # Guard to ensure we only process event deltas one at a time + self._event_processing = False + + if hs.config.worker.worker_app is None: + self._repl_client = None + + async def _schedule_db_events() -> None: + # We kick this off to pick up outstanding work from before the last restart. + # Block until we're up to date. + await self._unsafe_process_new_event() + hs.get_notifier().add_replication_callback(self.notify_new_event) + # Kick off again (without blocking) to catch any missed notifications + # that may have fired before the callback was added. + self._clock.call_later(0, self.notify_new_event) + + # Delayed events that are already marked as processed on startup might not have been + # sent properly on the last run of the server, so unmark them to send them again. + # Caveat: this will double-send delayed events that successfully persisted, but failed + # to be removed from the DB table of delayed events. + # TODO: To avoid double-sending, scan the timeline to find which of these events were + # already sent. To do so, must store delay_ids in sent events to retrieve them later. 
+ await self._store.unprocess_delayed_events() + + events, next_send_ts = await self._store.process_timeout_delayed_events( + self._get_current_ts() + ) + + if next_send_ts: + self._schedule_next_at(next_send_ts) + + # Can send the events in background after having awaited on marking them as processed + run_as_background_process( + "_send_events", + self._send_events, + events, + ) + + self._initialized_from_db = run_as_background_process( + "_schedule_db_events", _schedule_db_events + ) + else: + self._repl_client = ReplicationAddedDelayedEventRestServlet.make_client(hs) + + @property + def _is_master(self) -> bool: + return self._repl_client is None + + def notify_new_event(self) -> None: + """ + Called when there may be more state event deltas to process, + which should cancel pending delayed events for the same state. + """ + if self._event_processing: + return + + self._event_processing = True + + async def process() -> None: + try: + await self._unsafe_process_new_event() + finally: + self._event_processing = False + + run_as_background_process("delayed_events.notify_new_event", process) + + async def _unsafe_process_new_event(self) -> None: + # If self._event_pos is None then means we haven't fetched it from the DB yet + if self._event_pos is None: + self._event_pos = await self._store.get_delayed_events_stream_pos() + room_max_stream_ordering = self._store.get_room_max_stream_ordering() + if self._event_pos > room_max_stream_ordering: + # apparently, we've processed more events than exist in the database! + # this can happen if events are removed with history purge or similar. 
+ logger.warning( + "Event stream ordering appears to have gone backwards (%i -> %i): " + "rewinding delayed events processor", + self._event_pos, + room_max_stream_ordering, + ) + self._event_pos = room_max_stream_ordering + + # Loop round handling deltas until we're up to date + while True: + with Measure(self._clock, "delayed_events_delta"): + room_max_stream_ordering = self._store.get_room_max_stream_ordering() + if self._event_pos == room_max_stream_ordering: + return + + logger.debug( + "Processing delayed events %s->%s", + self._event_pos, + room_max_stream_ordering, + ) + ( + max_pos, + deltas, + ) = await self._storage_controllers.state.get_current_state_deltas( + self._event_pos, room_max_stream_ordering + ) + + logger.debug( + "Handling %d state deltas for delayed events processing", + len(deltas), + ) + await self._handle_state_deltas(deltas) + + self._event_pos = max_pos + + # Expose current event processing position to prometheus + event_processing_positions.labels("delayed_events").set(max_pos) + + await self._store.update_delayed_events_stream_pos(max_pos) + + async def _handle_state_deltas(self, deltas: List[StateDelta]) -> None: + """ + Process current state deltas to cancel other users' pending delayed events + that target the same state. + """ + for delta in deltas: + if delta.event_id is None: + logger.debug( + "Not handling delta for deleted state: %r %r", + delta.event_type, + delta.state_key, + ) + continue + + logger.debug( + "Handling: %r %r, %s", delta.event_type, delta.state_key, delta.event_id + ) + + event = await self._store.get_event( + delta.event_id, check_room_id=delta.room_id, allow_rejected=True, allow_none=True + ) + + if event is None or event.rejected_reason is not None: + # This event has been rejected, so we don't want to cancel any delayed events for it. 
+ continue + + sender = UserID.from_string(event.sender) + + next_send_ts = await self._store.cancel_delayed_state_events( + room_id=delta.room_id, + event_type=delta.event_type, + state_key=delta.state_key, + not_from_localpart=( + sender.localpart + if sender.domain == self._config.server.server_name + else "" + ), + ) + + if self._next_send_ts_changed(next_send_ts): + self._schedule_next_at_or_none(next_send_ts) + + async def add( + self, + requester: Requester, + *, + room_id: str, + event_type: str, + state_key: Optional[str], + origin_server_ts: Optional[int], + content: JsonDict, + delay: int, + ) -> str: + """ + Creates a new delayed event and schedules its delivery. + + Args: + requester: The requester of the delayed event, who will be its owner. + room_id: The ID of the room where the event should be sent to. + event_type: The type of event to be sent. + state_key: The state key of the event to be sent, or None if it is not a state event. + origin_server_ts: The custom timestamp to send the event with. + If None, the timestamp will be the actual time when the event is sent. + content: The content of the event to be sent. + delay: How long (in milliseconds) to wait before automatically sending the event. + + Returns: The ID of the added delayed event. + + Raises: + SynapseError: if the delayed event fails validation checks. + """ + # Use standard request limiter for scheduling new delayed events. + # TODO: Instead apply ratelimiting based on the scheduled send time. 
+ # See https://github.com/element-hq/synapse/issues/18021 + await self._request_ratelimiter.ratelimit(requester) + + self._event_creation_handler.validator.validate_builder( + self._event_creation_handler.event_builder_factory.for_room_version( + await self._store.get_room_version(room_id), + { + "type": event_type, + "content": content, + "room_id": room_id, + "sender": str(requester.user), + **({"state_key": state_key} if state_key is not None else {}), + }, + ) + ) + + creation_ts = self._get_current_ts() + + delay_id, next_send_ts = await self._store.add_delayed_event( + user_localpart=requester.user.localpart, + device_id=requester.device_id, + creation_ts=creation_ts, + room_id=room_id, + event_type=event_type, + state_key=state_key, + origin_server_ts=origin_server_ts, + content=content, + delay=delay, + ) + + if self._repl_client is not None: + # NOTE: If this throws, the delayed event will remain in the DB and + # will be picked up once the main worker gets another delayed event. + await self._repl_client( + instance_name=MAIN_PROCESS_INSTANCE_NAME, + next_send_ts=next_send_ts, + ) + elif self._next_send_ts_changed(next_send_ts): + self._schedule_next_at(next_send_ts) + + return delay_id + + def on_added(self, next_send_ts: int) -> None: + next_send_ts = Timestamp(next_send_ts) + if self._next_send_ts_changed(next_send_ts): + self._schedule_next_at(next_send_ts) + + async def cancel(self, requester: Requester, delay_id: str) -> None: + """ + Cancels the scheduled delivery of the matching delayed event. + + Args: + requester: The owner of the delayed event to act on. + delay_id: The ID of the delayed event to act on. + + Raises: + NotFoundError: if no matching delayed event could be found. 
+ """ + assert self._is_master + await self._delayed_event_mgmt_ratelimiter.ratelimit( + requester, + (requester.user.to_string(), requester.device_id), + ) + await self._initialized_from_db + + next_send_ts = await self._store.cancel_delayed_event( + delay_id=delay_id, + user_localpart=requester.user.localpart, + ) + + if self._next_send_ts_changed(next_send_ts): + self._schedule_next_at_or_none(next_send_ts) + + async def restart(self, requester: Requester, delay_id: str) -> None: + """ + Restarts the scheduled delivery of the matching delayed event. + + Args: + requester: The owner of the delayed event to act on. + delay_id: The ID of the delayed event to act on. + + Raises: + NotFoundError: if no matching delayed event could be found. + """ + assert self._is_master + await self._delayed_event_mgmt_ratelimiter.ratelimit( + requester, + (requester.user.to_string(), requester.device_id), + ) + await self._initialized_from_db + + next_send_ts = await self._store.restart_delayed_event( + delay_id=delay_id, + user_localpart=requester.user.localpart, + current_ts=self._get_current_ts(), + ) + + if self._next_send_ts_changed(next_send_ts): + self._schedule_next_at(next_send_ts) + + async def send(self, requester: Requester, delay_id: str) -> None: + """ + Immediately sends the matching delayed event, instead of waiting for its scheduled delivery. + + Args: + requester: The owner of the delayed event to act on. + delay_id: The ID of the delayed event to act on. + + Raises: + NotFoundError: if no matching delayed event could be found. + """ + assert self._is_master + # Use standard request limiter for sending delayed events on-demand, + # as an on-demand send is similar to sending a regular event. 
+ await self._request_ratelimiter.ratelimit(requester) + await self._initialized_from_db + + event, next_send_ts = await self._store.process_target_delayed_event( + delay_id=delay_id, + user_localpart=requester.user.localpart, + ) + + if self._next_send_ts_changed(next_send_ts): + self._schedule_next_at_or_none(next_send_ts) + + await self._send_event( + DelayedEventDetails( + delay_id=DelayID(delay_id), + user_localpart=UserLocalpart(requester.user.localpart), + room_id=event.room_id, + type=event.type, + state_key=event.state_key, + origin_server_ts=event.origin_server_ts, + content=event.content, + device_id=event.device_id, + ) + ) + + async def _send_on_timeout(self) -> None: + self._next_delayed_event_call = None + + events, next_send_ts = await self._store.process_timeout_delayed_events( + self._get_current_ts() + ) + + if next_send_ts: + self._schedule_next_at(next_send_ts) + + await self._send_events(events) + + async def _send_events(self, events: List[DelayedEventDetails]) -> None: + sent_state: Set[Tuple[RoomID, EventType, StateKey]] = set() + for event in events: + if event.state_key is not None: + state_info = (event.room_id, event.type, event.state_key) + if state_info in sent_state: + continue + else: + state_info = None + try: + # TODO: send in background if message event or non-conflicting state event + await self._send_event(event) + if state_info is not None: + sent_state.add(state_info) + except Exception: + logger.exception("Failed to send delayed event") + + for room_id, event_type, state_key in sent_state: + await self._store.delete_processed_delayed_state_events( + room_id=str(room_id), + event_type=event_type, + state_key=state_key, + ) + + def _schedule_next_at_or_none(self, next_send_ts: Optional[Timestamp]) -> None: + if next_send_ts is not None: + self._schedule_next_at(next_send_ts) + elif self._next_delayed_event_call is not None: + self._next_delayed_event_call.cancel() + self._next_delayed_event_call = None + + def 
_schedule_next_at(self, next_send_ts: Timestamp) -> None: + delay = next_send_ts - self._get_current_ts() + delay_sec = delay / 1000 if delay > 0 else 0 + + if self._next_delayed_event_call is None: + self._next_delayed_event_call = self._clock.call_later( + delay_sec, + run_as_background_process, + "_send_on_timeout", + self._send_on_timeout, + ) + else: + self._next_delayed_event_call.reset(delay_sec) + + async def get_all_for_user(self, requester: Requester) -> List[JsonDict]: + """Return all pending delayed events requested by the given user.""" + await self._delayed_event_mgmt_ratelimiter.ratelimit( + requester, + (requester.user.to_string(), requester.device_id), + ) + return await self._store.get_all_delayed_events_for_user( + requester.user.localpart + ) + + async def _send_event( + self, + event: DelayedEventDetails, + txn_id: Optional[str] = None, + ) -> None: + user_id = UserID(event.user_localpart, self._config.server.server_name) + user_id_str = user_id.to_string() + # Create a new requester from what data is currently available + requester = create_requester( + user_id, + is_guest=await self._store.is_guest(user_id_str), + device_id=event.device_id, + ) + + try: + if event.state_key is not None and event.type == EventTypes.Member: + membership = event.content.get("membership") + assert membership is not None + event_id, _ = await self._room_member_handler.update_membership( + requester, + target=UserID.from_string(event.state_key), + room_id=event.room_id.to_string(), + action=membership, + content=event.content, + origin_server_ts=event.origin_server_ts, + ) + else: + event_dict: JsonDict = { + "type": event.type, + "content": event.content, + "room_id": event.room_id.to_string(), + "sender": user_id_str, + } + + if event.origin_server_ts is not None: + event_dict["origin_server_ts"] = event.origin_server_ts + + if event.state_key is not None: + event_dict["state_key"] = event.state_key + + ( + sent_event, + _, + ) = await 
self._event_creation_handler.create_and_send_nonmember_event( + requester, + event_dict, + txn_id=txn_id, + ) + event_id = sent_event.event_id + except ShadowBanError: + event_id = generate_fake_event_id() + finally: + # TODO: If this is a temporary error, retry. Otherwise, consider notifying clients of the failure + try: + await self._store.delete_processed_delayed_event( + event.delay_id, event.user_localpart + ) + except Exception: + logger.exception("Failed to delete processed delayed event") + + set_tag("event_id", event_id) + + def _get_current_ts(self) -> Timestamp: + return Timestamp(self._clock.time_msec()) + + def _next_send_ts_changed(self, next_send_ts: Optional[Timestamp]) -> bool: + # The DB alone knows if the next send time changed after adding/modifying + # a delayed event, but if we were to ever miss updating our delayed call's + # firing time, we may miss other updates. So, keep track of changes to the + # the next send time here instead of in the DB. + cached_next_send_ts = ( + int(self._next_delayed_event_call.getTime() * 1000) + if self._next_delayed_event_call is not None + else None + ) + return next_send_ts != cached_next_send_ts diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 4fc6fcd7ae..f8b547bbed 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py
@@ -20,10 +20,21 @@ # # import logging -from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Set, Tuple +from threading import Lock +from typing import ( + TYPE_CHECKING, + AbstractSet, + Dict, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, +) from synapse.api import errors -from synapse.api.constants import EduTypes, EventTypes +from synapse.api.constants import EduTypes, EventTypes, Membership from synapse.api.errors import ( Codes, FederationDeniedError, @@ -38,6 +49,8 @@ from synapse.metrics.background_process_metrics import ( wrap_as_background_process, ) from synapse.storage.databases.main.client_ips import DeviceLastConnectionInfo +from synapse.storage.databases.main.roommember import EventIdMembership +from synapse.storage.databases.main.state_deltas import StateDelta from synapse.types import ( DeviceListUpdates, JsonDict, @@ -151,6 +164,8 @@ class DeviceWorkerHandler: raise errors.NotFoundError() ips = await self.store.get_last_client_ip_by_device(user_id, device_id) + + device = dict(device) _update_device_from_client_ips(device, ips) set_tag("device", str(device)) @@ -211,7 +226,6 @@ class DeviceWorkerHandler: return changed @trace - @measure_func("device.get_user_ids_changed") @cancellable async def get_user_ids_changed( self, user_id: str, from_token: StreamToken @@ -222,129 +236,113 @@ class DeviceWorkerHandler: set_tag("user_id", user_id) set_tag("from_token", str(from_token)) - now_room_key = self.store.get_room_max_token() - - room_ids = await self.store.get_rooms_for_user(user_id) - changed = await self.get_device_changes_in_shared_rooms( - user_id, room_ids, from_token - ) + now_token = self._event_sources.get_current_token() - # Then work out if any users have since joined - rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key) + # We need to work out all the different membership changes for the user + # and user they share a room with, to pass to + # 
`generate_sync_entry_for_device_list`. See its docstring for details + # on the data required. - member_events = await self.store.get_membership_changes_for_user( - user_id, from_token.room_key, now_room_key - ) - rooms_changed.update(event.room_id for event in member_events) - - stream_ordering = from_token.room_key.stream - - possibly_changed = set(changed) - possibly_left = set() - for room_id in rooms_changed: - # Check if the forward extremities have changed. If not then we know - # the current state won't have changed, and so we can skip this room. - try: - if not await self.store.have_room_forward_extremities_changed_since( - room_id, stream_ordering - ): - continue - except errors.StoreError: - pass + joined_room_ids = await self.store.get_rooms_for_user(user_id) - current_state_ids = await self._state_storage.get_current_state_ids( - room_id, await_full_state=False + # Get the set of rooms that the user has joined/left + membership_changes = ( + await self.store.get_current_state_delta_membership_changes_for_user( + user_id, from_key=from_token.room_key, to_key=now_token.room_key ) + ) - # The user may have left the room - # TODO: Check if they actually did or if we were just invited. - if room_id not in room_ids: - for etype, state_key in current_state_ids.keys(): - if etype != EventTypes.Member: - continue - possibly_left.add(state_key) - continue - - # Fetch the current state at the time. 
- try: - event_ids = await self.store.get_forward_extremities_for_room_at_stream_ordering( - room_id, stream_ordering=stream_ordering - ) - except errors.StoreError: - # we have purged the stream_ordering index since the stream - # ordering: treat it the same as a new room - event_ids = [] - - # special-case for an empty prev state: include all members - # in the changed list - if not event_ids: - log_kv( - {"event": "encountered empty previous state", "room_id": room_id} - ) - for etype, state_key in current_state_ids.keys(): - if etype != EventTypes.Member: - continue - possibly_changed.add(state_key) - continue - - current_member_id = current_state_ids.get((EventTypes.Member, user_id)) - if not current_member_id: + # Check for newly joined or left rooms. We need to make sure that we add + # to newly joined in the case membership goes from join -> leave -> join + # again. + newly_joined_rooms: Set[str] = set() + newly_left_rooms: Set[str] = set() + for change in membership_changes: + # We check for changes in "joinedness", i.e. if the membership has + # changed to or from JOIN. + if change.membership == Membership.JOIN: + if change.prev_membership != Membership.JOIN: + newly_joined_rooms.add(change.room_id) + newly_left_rooms.discard(change.room_id) + elif change.prev_membership == Membership.JOIN: + newly_joined_rooms.discard(change.room_id) + newly_left_rooms.add(change.room_id) + + # We now work out if any other users have since joined or left the rooms + # the user is currently in. + + # List of membership changes per room + room_to_deltas: Dict[str, List[StateDelta]] = {} + # The set of event IDs of membership events (so we can fetch their + # associated membership). + memberships_to_fetch: Set[str] = set() + + # TODO: Only pull out membership events? 
+ state_changes = await self.store.get_current_state_deltas_for_rooms( + joined_room_ids, from_token=from_token.room_key, to_token=now_token.room_key + ) + for delta in state_changes: + if delta.event_type != EventTypes.Member: continue - # mapping from event_id -> state_dict - prev_state_ids = await self._state_storage.get_state_ids_for_events( - event_ids, - await_full_state=False, + room_to_deltas.setdefault(delta.room_id, []).append(delta) + if delta.event_id: + memberships_to_fetch.add(delta.event_id) + if delta.prev_event_id: + memberships_to_fetch.add(delta.prev_event_id) + + # Fetch all the memberships for the membership events + event_id_to_memberships: Mapping[str, Optional[EventIdMembership]] = {} + if memberships_to_fetch: + event_id_to_memberships = await self.store.get_membership_from_event_ids( + memberships_to_fetch ) - # Check if we've joined the room? If so we just blindly add all the users to - # the "possibly changed" users. - for state_dict in prev_state_ids.values(): - member_event = state_dict.get((EventTypes.Member, user_id), None) - if not member_event or member_event != current_member_id: - for etype, state_key in current_state_ids.keys(): - if etype != EventTypes.Member: - continue - possibly_changed.add(state_key) - break - - # If there has been any change in membership, include them in the - # possibly changed list. We'll check if they are joined below, - # and we're not toooo worried about spuriously adding users. - for key, event_id in current_state_ids.items(): - etype, state_key = key - if etype != EventTypes.Member: - continue - - # check if this member has changed since any of the extremities - # at the stream_ordering, and add them to the list if so. 
- for state_dict in prev_state_ids.values(): - prev_event_id = state_dict.get(key, None) - if not prev_event_id or prev_event_id != event_id: - if state_key != user_id: - possibly_changed.add(state_key) - break - - if possibly_changed or possibly_left: - possibly_joined = possibly_changed - possibly_left = possibly_changed | possibly_left - - # Double check if we still share rooms with the given user. - users_rooms = await self.store.get_rooms_for_users(possibly_left) - for changed_user_id, entries in users_rooms.items(): - if any(rid in room_ids for rid in entries): - possibly_left.discard(changed_user_id) - else: - possibly_joined.discard(changed_user_id) - - else: - possibly_joined = set() - possibly_left = set() + joined_invited_knocked = ( + Membership.JOIN, + Membership.INVITE, + Membership.KNOCK, + ) - device_list_updates = DeviceListUpdates( - changed=possibly_joined, - left=possibly_left, + # We now want to find any user that have newly joined/invited/knocked, + # or newly left, similarly to above. + newly_joined_or_invited_or_knocked_users: Set[str] = set() + newly_left_users: Set[str] = set() + for _, deltas in room_to_deltas.items(): + for delta in deltas: + # Get the prev/new memberships for the delta + new_membership = None + prev_membership = None + if delta.event_id: + m = event_id_to_memberships.get(delta.event_id) + if m is not None: + new_membership = m.membership + if delta.prev_event_id: + m = event_id_to_memberships.get(delta.prev_event_id) + if m is not None: + prev_membership = m.membership + + # Check if a user has newly joined/invited/knocked, or left. 
+ if new_membership in joined_invited_knocked: + if prev_membership not in joined_invited_knocked: + newly_joined_or_invited_or_knocked_users.add(delta.state_key) + newly_left_users.discard(delta.state_key) + elif prev_membership in joined_invited_knocked: + newly_joined_or_invited_or_knocked_users.discard(delta.state_key) + newly_left_users.add(delta.state_key) + + # Now we actually calculate the device list entry with the information + # calculated above. + device_list_updates = await self.generate_sync_entry_for_device_list( + user_id=user_id, + since_token=from_token, + now_token=now_token, + joined_room_ids=joined_room_ids, + newly_joined_rooms=newly_joined_rooms, + newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users, + newly_left_rooms=newly_left_rooms, + newly_left_users=newly_left_users, ) log_kv( @@ -356,6 +354,87 @@ class DeviceWorkerHandler: return device_list_updates + async def generate_sync_entry_for_device_list( + self, + user_id: str, + since_token: StreamToken, + now_token: StreamToken, + joined_room_ids: AbstractSet[str], + newly_joined_rooms: AbstractSet[str], + newly_joined_or_invited_or_knocked_users: AbstractSet[str], + newly_left_rooms: AbstractSet[str], + newly_left_users: AbstractSet[str], + ) -> DeviceListUpdates: + """Generate the DeviceListUpdates section of sync + + Args: + sync_result_builder + newly_joined_rooms: Set of rooms user has joined since previous sync + newly_joined_or_invited_or_knocked_users: Set of users that have joined, + been invited to a room or are knocking on a room since + previous sync. + newly_left_rooms: Set of rooms user has left since previous sync + newly_left_users: Set of users that have left a room we're in since + previous sync + """ + # Take a copy since these fields will be mutated later. 
+ newly_joined_or_invited_or_knocked_users = set( + newly_joined_or_invited_or_knocked_users + ) + newly_left_users = set(newly_left_users) + + # We want to figure out what user IDs the client should refetch + # device keys for, and which users we aren't going to track changes + # for anymore. + # + # For the first step we check: + # a. if any users we share a room with have updated their devices, + # and + # b. we also check if we've joined any new rooms, or if a user has + # joined a room we're in. + # + # For the second step we just find any users we no longer share a + # room with by looking at all users that have left a room plus users + # that were in a room we've left. + + users_that_have_changed = set() + + # Step 1a, check for changes in devices of users we share a room + # with + users_that_have_changed = await self.get_device_changes_in_shared_rooms( + user_id, + joined_room_ids, + from_token=since_token, + now_token=now_token, + ) + + # Step 1b, check for newly joined rooms + for room_id in newly_joined_rooms: + joined_users = await self.store.get_users_in_room(room_id) + newly_joined_or_invited_or_knocked_users.update(joined_users) + + # TODO: Check that these users are actually new, i.e. either they + # weren't in the previous sync *or* they left and rejoined. + users_that_have_changed.update(newly_joined_or_invited_or_knocked_users) + + user_signatures_changed = await self.store.get_users_whose_signatures_changed( + user_id, since_token.device_list_key + ) + users_that_have_changed.update(user_signatures_changed) + + # Now find users that we no longer track + for room_id in newly_left_rooms: + left_users = await self.store.get_users_in_room(room_id) + newly_left_users.update(left_users) + + # Remove any users that we still share a room with. 
+ left_users_rooms = await self.store.get_rooms_for_users(newly_left_users) + for user_id, entries in left_users_rooms.items(): + if any(rid in joined_room_ids for rid in entries): + newly_left_users.discard(user_id) + + return DeviceListUpdates(changed=users_that_have_changed, left=newly_left_users) + async def on_federation_query_user_devices(self, user_id: str) -> JsonDict: if not self.hs.is_mine(UserID.from_string(user_id)): raise SynapseError(400, "User is not hosted on this homeserver") @@ -653,6 +732,40 @@ class DeviceHandler(DeviceWorkerHandler): await self.notify_device_update(user_id, device_ids) + async def upsert_device( + self, user_id: str, device_id: str, display_name: Optional[str] = None + ) -> bool: + """Create or update a device + + Args: + user_id: The user to update devices of. + device_id: The device to update. + display_name: The new display name for this device. + + Returns: + True if the device was created, False if it was updated. + + """ + + # Reject a new displayname which is too long. + self._check_device_name_length(display_name) + + created = await self.store.store_device( + user_id, + device_id, + initial_device_display_name=display_name, + ) + + if not created: + await self.store.update_device( + user_id, + device_id, + new_display_name=display_name, + ) + + await self.notify_device_update(user_id, [device_id]) + return created + async def update_device(self, user_id: str, device_id: str, content: dict) -> None: """Update the given device @@ -1125,7 +1238,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater): ) # Attempt to resync out of sync device lists every 30s. - self._resync_retry_in_progress = False + self._resync_retry_lock = Lock() self.clock.looping_call( run_as_background_process, 30 * 1000, @@ -1307,13 +1420,10 @@ class DeviceListUpdater(DeviceListWorkerUpdater): """Retry to resync device lists that are out of sync, except if another retry is in progress. 
""" - if self._resync_retry_in_progress: + # If the lock can not be acquired we want to always return immediately instead of blocking here + if not self._resync_retry_lock.acquire(blocking=False): return - try: - # Prevent another call of this function to retry resyncing device lists so - # we don't send too many requests. - self._resync_retry_in_progress = True # Get all of the users that need resyncing. need_resync = await self.store.get_user_ids_requiring_device_list_resync() @@ -1354,8 +1464,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater): e, ) finally: - # Allow future calls to retry resyncinc out of sync device lists. - self._resync_retry_in_progress = False + self._resync_retry_lock.release() async def multi_user_device_resync( self, user_ids: List[str], mark_failed_as_stale: bool = True diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index ad2b0f5fcc..48c7d411d5 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py
@@ -21,9 +21,7 @@ import logging import string -from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence - -from typing_extensions import Literal +from typing import TYPE_CHECKING, Iterable, List, Literal, Optional, Sequence from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes from synapse.api.errors import ( @@ -265,9 +263,9 @@ class DirectoryHandler: async def get_association(self, room_alias: RoomAlias) -> JsonDict: room_id = None if self.hs.is_mine(room_alias): - result: Optional[RoomAliasMapping] = ( - await self.get_association_from_room_alias(room_alias) - ) + result: Optional[ + RoomAliasMapping + ] = await self.get_association_from_room_alias(room_alias) if result: room_id = result.room_id @@ -512,11 +510,9 @@ class DirectoryHandler: raise SynapseError(403, "Not allowed to publish room") # Check if publishing is blocked by a third party module - allowed_by_third_party_rules = ( - await ( - self._third_party_event_rules.check_visibility_can_be_modified( - room_id, visibility - ) + allowed_by_third_party_rules = await ( + self._third_party_event_rules.check_visibility_can_be_modified( + room_id, visibility ) ) if not allowed_by_third_party_rules: diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index f78e66ad0a..6171aaf29f 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py
@@ -39,6 +39,8 @@ from synapse.replication.http.devices import ReplicationUploadKeysForUserRestSer from synapse.types import ( JsonDict, JsonMapping, + ScheduledTask, + TaskStatus, UserID, get_domain_from_id, get_verify_key_from_cross_signing_key, @@ -70,6 +72,7 @@ class E2eKeysHandler: self.is_mine = hs.is_mine self.clock = hs.get_clock() self._worker_lock_handler = hs.get_worker_locks_handler() + self._task_scheduler = hs.get_task_scheduler() federation_registry = hs.get_federation_registry() @@ -116,6 +119,10 @@ class E2eKeysHandler: hs.config.experimental.msc3984_appservice_key_query ) + self._task_scheduler.register_action( + self._delete_old_one_time_keys_task, "delete_old_otks" + ) + @trace @cancellable async def query_devices( @@ -151,7 +158,37 @@ class E2eKeysHandler: the number of in-flight queries at a time. """ async with self._query_devices_linearizer.queue((from_user_id, from_device_id)): - device_keys_query: Dict[str, List[str]] = query_body.get("device_keys", {}) + + async def filter_device_key_query( + query: Dict[str, List[str]], + ) -> Dict[str, List[str]]: + if not self.config.experimental.msc4263_limit_key_queries_to_users_who_share_rooms: + # Only ignore invalid user IDs, which is the same behaviour as if + # the user existed but had no keys. + return { + user_id: v + for user_id, v in query.items() + if UserID.is_valid(user_id) + } + + # Strip invalid user IDs and user IDs the requesting user does not share rooms with. + valid_user_ids = [ + user_id for user_id in query.keys() if UserID.is_valid(user_id) + ] + allowed_user_ids = set( + await self.store.do_users_share_a_room_joined_or_invited( + from_user_id, valid_user_ids + ) + ) + return { + user_id: v + for user_id, v in query.items() + if user_id in allowed_user_ids + } + + device_keys_query: Dict[str, List[str]] = await filter_device_key_query( + query_body.get("device_keys", {}) + ) # separate users by domain. 
# make a map from domain to user_id to device_ids @@ -159,11 +196,6 @@ class E2eKeysHandler: remote_queries = {} for user_id, device_ids in device_keys_query.items(): - if not UserID.is_valid(user_id): - # Ignore invalid user IDs, which is the same behaviour as if - # the user existed but had no keys. - continue - # we use UserID.from_string to catch invalid user ids if self.is_mine(UserID.from_string(user_id)): local_query[user_id] = device_ids @@ -615,7 +647,7 @@ class E2eKeysHandler: 3. Attempt to fetch fallback keys from the database. Args: - local_query: An iterable of tuples of (user ID, device ID, algorithm). + local_query: An iterable of tuples of (user ID, device ID, algorithm, number of keys). always_include_fallback_keys: True to always include fallback keys. Returns: @@ -1156,7 +1188,7 @@ class E2eKeysHandler: devices = devices[user_id] except SynapseError as e: failure = _exception_to_failure(e) - failures[user_id] = {device: failure for device in signatures.keys()} + failures[user_id] = dict.fromkeys(signatures.keys(), failure) return signature_list, failures for device_id, device in signatures.items(): @@ -1296,7 +1328,7 @@ class E2eKeysHandler: except SynapseError as e: failure = _exception_to_failure(e) for user, devicemap in signatures.items(): - failures[user] = {device_id: failure for device_id in devicemap.keys()} + failures[user] = dict.fromkeys(devicemap.keys(), failure) return signature_list, failures for target_user, devicemap in signatures.items(): @@ -1337,9 +1369,7 @@ class E2eKeysHandler: # other devices were signed -- mark those as failures logger.debug("upload signature: too many devices specified") failure = _exception_to_failure(NotFoundError("Unknown device")) - failures[target_user] = { - device: failure for device in other_devices - } + failures[target_user] = dict.fromkeys(other_devices, failure) if user_signing_key_id in master_key.get("signatures", {}).get( user_id, {} @@ -1360,9 +1390,7 @@ class E2eKeysHandler: except 
SynapseError as e: failure = _exception_to_failure(e) if device_id is None: - failures[target_user] = { - device_id: failure for device_id in devicemap.keys() - } + failures[target_user] = dict.fromkeys(devicemap.keys(), failure) else: failures.setdefault(target_user, {})[device_id] = failure @@ -1574,6 +1602,45 @@ class E2eKeysHandler: return True return False + async def _delete_old_one_time_keys_task( + self, task: ScheduledTask + ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + """Scheduler task to delete old one time keys. + + Until Synapse 1.119, Synapse used to issue one-time-keys in a random order, leading to the possibility + that it could still have old OTKs that the client has dropped. This task is scheduled exactly once + by a database schema delta file, and it clears out old one-time-keys that look like they came from libolm. + """ + last_user = task.result.get("from_user", "") if task.result else "" + while True: + # We process users in batches of 100 + users, rowcount = await self.store.delete_old_otks_for_next_user_batch( + last_user, 100 + ) + if len(users) == 0: + # We're done! + return TaskStatus.COMPLETE, None, None + + logger.debug( + "Deleted %i old one-time-keys for users '%s'..'%s'", + rowcount, + users[0], + users[-1], + ) + last_user = users[-1] + + # Store our progress + await self._task_scheduler.update_task( + task.id, result={"from_user": last_user} + ) + + # Sleep a little before doing the next user. + # + # matrix.org has about 15M users in the e2e_one_time_keys_json table + # (comprising 20M devices). We want this to take about a week, so we need + # to do about one batch of 100 users every 4 seconds. + await self.clock.sleep(4) + def _check_cross_signing_key( key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index f397911f28..623fd33f13 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py
@@ -20,9 +20,7 @@ # import logging -from typing import TYPE_CHECKING, Dict, Optional, cast - -from typing_extensions import Literal +from typing import TYPE_CHECKING, Dict, Literal, Optional, cast from synapse.api.errors import ( Codes, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 299588e476..ff751d25f6 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py
@@ -78,6 +78,7 @@ from synapse.replication.http.federation import ( ReplicationStoreRoomOnOutlierMembershipRestServlet, ) from synapse.storage.databases.main.events_worker import EventRedactBehaviour +from synapse.storage.invite_rule import InviteRule from synapse.types import JsonDict, StrCollection, get_domain_from_id from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer @@ -210,7 +211,7 @@ class FederationHandler: @tag_args async def maybe_backfill( self, room_id: str, current_depth: int, limit: int, record_time: bool = True - ) -> bool: + ) -> None: """Checks the database to see if we should backfill before paginating, and if so do. @@ -224,8 +225,6 @@ class FederationHandler: should back paginate. record_time: Whether to record the time it takes to backfill. - Returns: - True if we actually tried to backfill something, otherwise False. """ # Starting the processing time here so we can include the room backfill # linearizer lock queue in the timing @@ -251,7 +250,7 @@ class FederationHandler: limit: int, *, processing_start_time: Optional[int], - ) -> bool: + ) -> None: """ Checks whether the `current_depth` is at or approaching any backfill points in the room and if so, will backfill. We only care about @@ -325,7 +324,7 @@ class FederationHandler: limit=1, ) if not have_later_backfill_points: - return False + return None logger.debug( "_maybe_backfill_inner: all backfill points are *after* current depth. Trying again with later backfill points." @@ -345,15 +344,15 @@ class FederationHandler: ) # We return `False` because we're backfilling in the background and there is # no new events immediately for the caller to know about yet. - return False + return None # Even after recursing with `MAX_DEPTH`, we didn't find any # backward extremities to backfill from. if not sorted_backfill_points: logger.debug( - "_maybe_backfill_inner: Not backfilling as no backward extremeties found." 
+ "_maybe_backfill_inner: Not backfilling as no backward extremities found." ) - return False + return None # If we're approaching an extremity we trigger a backfill, otherwise we # no-op. @@ -372,7 +371,7 @@ class FederationHandler: current_depth, limit, ) - return False + return None # For performance's sake, we only want to paginate from a particular extremity # if we can actually see the events we'll get. Otherwise, we'd just spend a lot @@ -440,7 +439,7 @@ class FederationHandler: logger.debug( "_maybe_backfill_inner: found no extremities which would be visible" ) - return False + return None logger.debug( "_maybe_backfill_inner: extremities_to_request %s", extremities_to_request @@ -463,7 +462,7 @@ class FederationHandler: ) ) - async def try_backfill(domains: StrCollection) -> bool: + async def try_backfill(domains: StrCollection) -> None: # TODO: Should we try multiple of these at a time? # Number of contacted remote homeservers that have denied our backfill @@ -486,7 +485,7 @@ class FederationHandler: # If this succeeded then we probably already have the # appropriate stuff. # TODO: We can probably do something more intelligent here. - return True + return None except NotRetryingDestination as e: logger.info("_maybe_backfill_inner: %s", e) continue @@ -510,7 +509,7 @@ class FederationHandler: ) denied_count += 1 if denied_count >= max_denied_count: - return False + return None continue logger.info("Failed to backfill from %s because %s", dom, e) @@ -526,7 +525,7 @@ class FederationHandler: ) denied_count += 1 if denied_count >= max_denied_count: - return False + return None continue logger.info("Failed to backfill from %s because %s", dom, e) @@ -538,7 +537,7 @@ class FederationHandler: logger.exception("Failed to backfill from %s because %s", dom, e) continue - return False + return None # If we have the `processing_start_time`, then we can make an # observation. 
We wouldn't have the `processing_start_time` in the case @@ -550,14 +549,9 @@ class FederationHandler: (processing_end_time - processing_start_time) / 1000 ) - success = await try_backfill(likely_domains) - if success: - return True - # TODO: we could also try servers which were previously in the room, but # are no longer. - - return False + return await try_backfill(likely_domains) async def send_invite(self, target_host: str, event: EventBase) -> EventBase: """Sends the invite to the remote server for signing. @@ -880,6 +874,9 @@ class FederationHandler: if stripped_room_state is None: raise KeyError("Missing 'knock_room_state' field in send_knock response") + if not isinstance(stripped_room_state, list): + raise TypeError("'knock_room_state' has wrong type") + event.unsigned["knock_room_state"] = stripped_room_state context = EventContext.for_outlier(self._storage_controllers) @@ -1001,11 +998,11 @@ class FederationHandler: ) if include_auth_user_id: - event_content[EventContentFields.AUTHORISING_USER] = ( - await self._event_auth_handler.get_user_which_could_invite( - room_id, - state_ids, - ) + event_content[ + EventContentFields.AUTHORISING_USER + ] = await self._event_auth_handler.get_user_which_could_invite( + room_id, + state_ids, ) builder = self.event_builder_factory.for_room_version( @@ -1086,6 +1083,20 @@ class FederationHandler: if event.state_key == self._server_notices_mxid: raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user") + # check the invitee's configuration and apply rules + invite_config = await self.store.get_invite_config_for_user(event.state_key) + rule = invite_config.get_invite_rule(event.sender) + if rule == InviteRule.BLOCK: + logger.info( + f"Automatically rejecting invite from {event.sender} due to the invite filtering rules of {event.state_key}" + ) + raise SynapseError( + 403, + "You are not permitted to invite this user.", + errcode=Codes.INVITE_BLOCKED, + ) + # InviteRule.IGNORE is handled at the sync layer + # We 
retrieve the room member handler here as to not cause a cyclic dependency member_handler = self.hs.get_room_member_handler() # We don't rate limit based on room ID, as that should be done by @@ -1309,9 +1320,9 @@ class FederationHandler: if state_key is not None: # the event was not rejected (get_event raises a NotFoundError for rejected # events) so the state at the event should include the event itself. - assert ( - state_map.get((event.type, state_key)) == event.event_id - ), "State at event did not include event itself" + assert state_map.get((event.type, state_key)) == event.event_id, ( + "State at event did not include event itself" + ) # ... but we need the state *before* that event if "replaces_state" in event.unsigned: diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index c85deaed56..1e738f484f 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py
@@ -151,6 +151,8 @@ class FederationEventHandler: def __init__(self, hs: "HomeServer"): self._clock = hs.get_clock() self._store = hs.get_datastores().main + self._state_store = hs.get_datastores().state + self._state_deletion_store = hs.get_datastores().state_deletion self._storage_controllers = hs.get_storage_controllers() self._state_storage_controller = self._storage_controllers.state @@ -580,7 +582,9 @@ class FederationEventHandler: room_version.identifier, state_maps_to_resolve, event_map=None, - state_res_store=StateResolutionStore(self._store), + state_res_store=StateResolutionStore( + self._store, self._state_deletion_store + ), ) ) else: @@ -1179,7 +1183,9 @@ class FederationEventHandler: room_version, state_maps, event_map={event_id: event}, - state_res_store=StateResolutionStore(self._store), + state_res_store=StateResolutionStore( + self._store, self._state_deletion_store + ), ) except Exception as e: @@ -1874,7 +1880,9 @@ class FederationEventHandler: room_version, [local_state_id_map, claimed_auth_events_id_map], event_map=None, - state_res_store=StateResolutionStore(self._store), + state_res_store=StateResolutionStore( + self._store, self._state_deletion_store + ), ) ) else: @@ -2014,7 +2022,9 @@ class FederationEventHandler: room_version, state_sets, event_map=None, - state_res_store=StateResolutionStore(self._store), + state_res_store=StateResolutionStore( + self._store, self._state_deletion_store + ), ) ) else: @@ -2272,8 +2282,9 @@ class FederationEventHandler: event_and_contexts, backfilled=backfilled ) - # After persistence we always need to notify replication there may - # be new data. 
+ # After persistence, we never notify clients (wake up `/sync` streams) about + # backfilled events but it's important to let all the workers know about any + # new event (backfilled or not) because TODO self._notifier.notify_replication() if self._ephemeral_messages_enabled: diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py deleted file mode 100644
index cb31d65aa9..0000000000 --- a/synapse/handlers/identity.py +++ /dev/null
@@ -1,811 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright 2017 Vector Creations Ltd -# Copyright 2015, 2016 OpenMarket Ltd -# Copyright (C) 2023 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -# Originally licensed under the Apache License, Version 2.0: -# <http://www.apache.org/licenses/LICENSE-2.0>. -# -# [This file includes modifications made by New Vector Limited] -# -# - -"""Utilities for interacting with Identity Servers""" -import logging -import urllib.parse -from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple - -import attr - -from synapse.api.errors import ( - CodeMessageException, - Codes, - HttpResponseException, - SynapseError, -) -from synapse.api.ratelimiting import Ratelimiter -from synapse.http import RequestTimedOutError -from synapse.http.client import SimpleHttpClient -from synapse.http.site import SynapseRequest -from synapse.types import JsonDict, Requester -from synapse.util import json_decoder -from synapse.util.hash import sha256_and_url_safe_base64 -from synapse.util.stringutils import ( - assert_valid_client_secret, - random_string, - valid_id_server_location, -) - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - -id_server_scheme = "https://" - - -class IdentityHandler: - def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastores().main - # An HTTP client for contacting trusted URLs. - self.http_client = SimpleHttpClient(hs) - # An HTTP client for contacting identity servers specified by clients. 
- self._http_client = SimpleHttpClient( - hs, - ip_blocklist=hs.config.server.federation_ip_range_blocklist, - ip_allowlist=hs.config.server.federation_ip_range_allowlist, - ) - self.federation_http_client = hs.get_federation_http_client() - self.hs = hs - - self._web_client_location = hs.config.email.invite_client_location - - # Ratelimiters for `/requestToken` endpoints. - self._3pid_validation_ratelimiter_ip = Ratelimiter( - store=self.store, - clock=hs.get_clock(), - cfg=hs.config.ratelimiting.rc_3pid_validation, - ) - self._3pid_validation_ratelimiter_address = Ratelimiter( - store=self.store, - clock=hs.get_clock(), - cfg=hs.config.ratelimiting.rc_3pid_validation, - ) - - async def ratelimit_request_token_requests( - self, - request: SynapseRequest, - medium: str, - address: str, - ) -> None: - """Used to ratelimit requests to `/requestToken` by IP and address. - - Args: - request: The associated request - medium: The type of threepid, e.g. "msisdn" or "email" - address: The actual threepid ID, e.g. the phone number or email address - """ - - await self._3pid_validation_ratelimiter_ip.ratelimit( - None, (medium, request.getClientAddress().host) - ) - await self._3pid_validation_ratelimiter_address.ratelimit( - None, (medium, address) - ) - - async def threepid_from_creds( - self, id_server: str, creds: Dict[str, str] - ) -> Optional[JsonDict]: - """ - Retrieve and validate a threepid identifier from a "credentials" dictionary against a - given identity server - - Args: - id_server: The identity server to validate 3PIDs against. 
Must be a - complete URL including the protocol (http(s)://) - creds: Dictionary containing the following keys: - * client_secret|clientSecret: A unique secret str provided by the client - * sid: The ID of the validation session - - Returns: - A dictionary consisting of response params to the /getValidated3pid - endpoint of the Identity Service API, or None if the threepid was not found - """ - client_secret = creds.get("client_secret") or creds.get("clientSecret") - if not client_secret: - raise SynapseError( - 400, "Missing param client_secret in creds", errcode=Codes.MISSING_PARAM - ) - assert_valid_client_secret(client_secret) - - session_id = creds.get("sid") - if not session_id: - raise SynapseError( - 400, "Missing param session_id in creds", errcode=Codes.MISSING_PARAM - ) - - query_params = {"sid": session_id, "client_secret": client_secret} - - url = id_server + "/_matrix/identity/api/v1/3pid/getValidated3pid" - - try: - data = await self.http_client.get_json(url, query_params) - except RequestTimedOutError: - raise SynapseError(500, "Timed out contacting identity server") - except HttpResponseException as e: - logger.info( - "%s returned %i for threepid validation for: %s", - id_server, - e.code, - creds, - ) - return None - - # Old versions of Sydent return a 200 http code even on a failed validation - # check. 
Thus, in addition to the HttpResponseException check above (which - # checks for non-200 errors), we need to make sure validation_session isn't - # actually an error, identified by the absence of a "medium" key - # See https://github.com/matrix-org/sydent/issues/215 for details - if "medium" in data: - return data - - logger.info("%s reported non-validated threepid: %s", id_server, creds) - return None - - async def bind_threepid( - self, - client_secret: str, - sid: str, - mxid: str, - id_server: str, - id_access_token: str, - ) -> JsonDict: - """Bind a 3PID to an identity server - - Args: - client_secret: A unique secret provided by the client - sid: The ID of the validation session - mxid: The MXID to bind the 3PID to - id_server: The domain of the identity server to query - id_access_token: The access token to authenticate to the identity - server with - - Raises: - SynapseError: On any of the following conditions - - the supplied id_server is not a valid identity server name - - we failed to contact the supplied identity server - - Returns: - The response from the identity server - """ - logger.debug("Proxying threepid bind request for %s to %s", mxid, id_server) - - if not valid_id_server_location(id_server): - raise SynapseError( - 400, - "id_server must be a valid hostname with optional port and path components", - ) - - bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid} - bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,) - headers = {"Authorization": create_id_access_token_header(id_access_token)} - - try: - # Use the blacklisting http client as this call is only to identity servers - # provided by a client - data = await self._http_client.post_json_get_json( - bind_url, bind_data, headers=headers - ) - - # Remember where we bound the threepid - await self.store.add_user_bound_threepid( - user_id=mxid, - medium=data["medium"], - address=data["address"], - id_server=id_server, - ) - - return data - except 
HttpResponseException as e: - logger.error("3PID bind failed with Matrix error: %r", e) - raise e.to_synapse_error() - except RequestTimedOutError: - raise SynapseError(500, "Timed out contacting identity server") - except CodeMessageException as e: - data = json_decoder.decode(e.msg) # XXX WAT? - return data - - async def try_unbind_threepid( - self, mxid: str, medium: str, address: str, id_server: Optional[str] - ) -> bool: - """Attempt to remove a 3PID from one or more identity servers. - - Args: - mxid: Matrix user ID of binding to be removed - medium: The medium of the third-party ID. - address: The address of the third-party ID. - id_server: An identity server to attempt to unbind from. If None, - attempt to remove the association from all identity servers - known to potentially have it. - - Raises: - SynapseError: If we failed to contact one or more identity servers. - - Returns: - True on success, otherwise False if the identity server doesn't - support unbinding (or no identity server to contact was found). 
- """ - if id_server: - id_servers = [id_server] - else: - id_servers = await self.store.get_id_servers_user_bound( - mxid, medium, address - ) - - # We don't know where to unbind, so we don't have a choice but to return - if not id_servers: - return False - - changed = True - for id_server in id_servers: - changed &= await self._try_unbind_threepid_with_id_server( - mxid, medium, address, id_server - ) - - return changed - - async def _try_unbind_threepid_with_id_server( - self, mxid: str, medium: str, address: str, id_server: str - ) -> bool: - """Removes a binding from an identity server - - Args: - mxid: Matrix user ID of binding to be removed - medium: The medium of the third-party ID - address: The address of the third-party ID - id_server: Identity server to unbind from - - Raises: - SynapseError: On any of the following conditions - - the supplied id_server is not a valid identity server name - - we failed to contact the supplied identity server - - Returns: - True on success, otherwise False if the identity - server doesn't support unbinding - """ - - if not valid_id_server_location(id_server): - raise SynapseError( - 400, - "id_server must be a valid hostname with optional port and path components", - ) - - url = "https://%s/_matrix/identity/v2/3pid/unbind" % (id_server,) - url_bytes = b"/_matrix/identity/v2/3pid/unbind" - - content = { - "mxid": mxid, - "threepid": {"medium": medium, "address": address}, - } - - # we abuse the federation http client to sign the request, but we have to send it - # using the normal http client since we don't want the SRV lookup and want normal - # 'browser-like' HTTPS. 
- auth_headers = self.federation_http_client.build_auth_headers( - destination=None, - method=b"POST", - url_bytes=url_bytes, - content=content, - destination_is=id_server.encode("ascii"), - ) - headers = {b"Authorization": auth_headers} - - try: - # Use the blacklisting http client as this call is only to identity servers - # provided by a client - await self._http_client.post_json_get_json(url, content, headers) - changed = True - except HttpResponseException as e: - changed = False - if e.code in (400, 404, 501): - # The remote server probably doesn't support unbinding (yet) - logger.warning("Received %d response while unbinding threepid", e.code) - else: - logger.error("Failed to unbind threepid on identity server: %s", e) - raise SynapseError(500, "Failed to contact identity server") - except RequestTimedOutError: - raise SynapseError(500, "Timed out contacting identity server") - - await self.store.remove_user_bound_threepid(mxid, medium, address, id_server) - - return changed - - async def send_threepid_validation( - self, - email_address: str, - client_secret: str, - send_attempt: int, - send_email_func: Callable[[str, str, str, str], Awaitable], - next_link: Optional[str] = None, - ) -> str: - """Send a threepid validation email for password reset or - registration purposes - - Args: - email_address: The user's email address - client_secret: The provided client secret - send_attempt: Which send attempt this is - send_email_func: A function that takes an email address, token, - client_secret and session_id, sends an email - and returns an Awaitable. 
- next_link: The URL to redirect the user to after validation - - Returns: - The new session_id upon success - - Raises: - SynapseError is an error occurred when sending the email - """ - # Check that this email/client_secret/send_attempt combo is new or - # greater than what we've seen previously - session = await self.store.get_threepid_validation_session( - "email", client_secret, address=email_address, validated=False - ) - - # Check to see if a session already exists and that it is not yet - # marked as validated - if session and session.validated_at is None: - session_id = session.session_id - last_send_attempt = session.last_send_attempt - - # Check that the send_attempt is higher than previous attempts - if send_attempt <= last_send_attempt: - # If not, just return a success without sending an email - return session_id - else: - # An non-validated session does not exist yet. - # Generate a session id - session_id = random_string(16) - - if next_link: - # Manipulate the next_link to add the sid, because the caller won't get - # it until we send a response, by which time we've sent the mail. - if "?" in next_link: - next_link += "&" - else: - next_link += "?" 
- next_link += "sid=" + urllib.parse.quote(session_id) - - # Generate a new validation token - token = random_string(32) - - # Send the mail with the link containing the token, client_secret - # and session_id - try: - await send_email_func(email_address, token, client_secret, session_id) - except Exception: - logger.exception( - "Error sending threepid validation email to %s", email_address - ) - raise SynapseError(500, "An error was encountered when sending the email") - - token_expires = ( - self.hs.get_clock().time_msec() - + self.hs.config.email.email_validation_token_lifetime - ) - - await self.store.start_or_continue_validation_session( - "email", - email_address, - session_id, - client_secret, - send_attempt, - next_link, - token, - token_expires, - ) - - return session_id - - async def requestMsisdnToken( - self, - id_server: str, - country: str, - phone_number: str, - client_secret: str, - send_attempt: int, - next_link: Optional[str] = None, - ) -> JsonDict: - """ - Request an external server send an SMS message on our behalf for the purposes of - threepid validation. 
- Args: - id_server: The identity server to proxy to - country: The country code of the phone number - phone_number: The number to send the message to - client_secret: The unique client_secret sends by the user - send_attempt: Which attempt this is - next_link: A link to redirect the user to once they submit the token - - Returns: - The json response body from the server - """ - params = { - "country": country, - "phone_number": phone_number, - "client_secret": client_secret, - "send_attempt": send_attempt, - } - if next_link: - params["next_link"] = next_link - - try: - data = await self.http_client.post_json_get_json( - id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken", - params, - ) - except HttpResponseException as e: - logger.info("Proxied requestToken failed: %r", e) - raise e.to_synapse_error() - except RequestTimedOutError: - raise SynapseError(500, "Timed out contacting identity server") - - # we need to tell the client to send the token back to us, since it doesn't - # otherwise know where to send it, so add submit_url response parameter - # (see also MSC2078) - data["submit_url"] = ( - self.hs.config.server.public_baseurl - + "_matrix/client/unstable/add_threepid/msisdn/submit_token" - ) - return data - - async def validate_threepid_session( - self, client_secret: str, sid: str - ) -> Optional[JsonDict]: - """Validates a threepid session with only the client secret and session ID - Tries validating against any configured account_threepid_delegates as well as locally. - - Args: - client_secret: A secret provided by the client - sid: The ID of the session - - Returns: - The json response if validation was successful, otherwise None - """ - # XXX: We shouldn't need to keep wrapping and unwrapping this value - threepid_creds = {"client_secret": client_secret, "sid": sid} - - # We don't actually know which medium this 3PID is. 
Thus we first assume it's email, - # and if validation fails we try msisdn - - # Try to validate as email - if self.hs.config.email.can_verify_email: - # Get a validated session matching these details - validation_session = await self.store.get_threepid_validation_session( - "email", client_secret, sid=sid, validated=True - ) - if validation_session: - return attr.asdict(validation_session) - - # Try to validate as msisdn - if self.hs.config.registration.account_threepid_delegate_msisdn: - # Ask our delegated msisdn identity server - return await self.threepid_from_creds( - self.hs.config.registration.account_threepid_delegate_msisdn, - threepid_creds, - ) - - return None - - async def proxy_msisdn_submit_token( - self, id_server: str, client_secret: str, sid: str, token: str - ) -> JsonDict: - """Proxy a POST submitToken request to an identity server for verification purposes - - Args: - id_server: The identity server URL to contact - client_secret: Secret provided by the client - sid: The ID of the session - token: The verification token - - Raises: - SynapseError: If we failed to contact the identity server - - Returns: - The response dict from the identity server - """ - body = {"client_secret": client_secret, "sid": sid, "token": token} - - try: - return await self.http_client.post_json_get_json( - id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken", - body, - ) - except RequestTimedOutError: - raise SynapseError(500, "Timed out contacting identity server") - except HttpResponseException as e: - logger.warning("Error contacting msisdn account_threepid_delegate: %s", e) - raise SynapseError(400, "Error contacting the identity server") - - async def lookup_3pid( - self, id_server: str, medium: str, address: str, id_access_token: str - ) -> Optional[str]: - """Looks up a 3pid in the passed identity server. - - Args: - id_server: The server name (including port, if required) - of the identity server to use. 
- medium: The type of the third party identifier (e.g. "email"). - address: The third party identifier (e.g. "foo@example.com"). - id_access_token: The access token to authenticate to the identity - server with - - Returns: - the matrix ID of the 3pid, or None if it is not recognized. - """ - - try: - results = await self._lookup_3pid_v2( - id_server, id_access_token, medium, address - ) - return results - except Exception as e: - logger.warning("Error when looking up hashing details: %s", e) - return None - - async def _lookup_3pid_v2( - self, id_server: str, id_access_token: str, medium: str, address: str - ) -> Optional[str]: - """Looks up a 3pid in the passed identity server using v2 lookup. - - Args: - id_server: The server name (including port, if required) - of the identity server to use. - id_access_token: The access token to authenticate to the identity server with - medium: The type of the third party identifier (e.g. "email"). - address: The third party identifier (e.g. "foo@example.com"). - - Returns: - the matrix ID of the 3pid, or None if it is not recognised. 
- """ - # Check what hashing details are supported by this identity server - try: - hash_details = await self._http_client.get_json( - "%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server), - {"access_token": id_access_token}, - ) - except RequestTimedOutError: - raise SynapseError(500, "Timed out contacting identity server") - - if not isinstance(hash_details, dict): - logger.warning( - "Got non-dict object when checking hash details of %s%s: %s", - id_server_scheme, - id_server, - hash_details, - ) - raise SynapseError( - 400, - "Non-dict object from %s%s during v2 hash_details request: %s" - % (id_server_scheme, id_server, hash_details), - ) - - # Extract information from hash_details - supported_lookup_algorithms = hash_details.get("algorithms") - lookup_pepper = hash_details.get("lookup_pepper") - if ( - not supported_lookup_algorithms - or not isinstance(supported_lookup_algorithms, list) - or not lookup_pepper - or not isinstance(lookup_pepper, str) - ): - raise SynapseError( - 400, - "Invalid hash details received from identity server %s%s: %s" - % (id_server_scheme, id_server, hash_details), - ) - - # Check if any of the supported lookup algorithms are present - if LookupAlgorithm.SHA256 in supported_lookup_algorithms: - # Perform a hashed lookup - lookup_algorithm = LookupAlgorithm.SHA256 - - # Hash address, medium and the pepper with sha256 - to_hash = "%s %s %s" % (address, medium, lookup_pepper) - lookup_value = sha256_and_url_safe_base64(to_hash) - - elif LookupAlgorithm.NONE in supported_lookup_algorithms: - # Perform a non-hashed lookup - lookup_algorithm = LookupAlgorithm.NONE - - # Combine together plaintext address and medium - lookup_value = "%s %s" % (address, medium) - - else: - logger.warning( - "None of the provided lookup algorithms of %s are supported: %s", - id_server, - supported_lookup_algorithms, - ) - raise SynapseError( - 400, - "Provided identity server does not support any v2 lookup " - "algorithms that this 
homeserver supports.", - ) - - # Authenticate with identity server given the access token from the client - headers = {"Authorization": create_id_access_token_header(id_access_token)} - - try: - lookup_results = await self._http_client.post_json_get_json( - "%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server), - { - "addresses": [lookup_value], - "algorithm": lookup_algorithm, - "pepper": lookup_pepper, - }, - headers=headers, - ) - except RequestTimedOutError: - raise SynapseError(500, "Timed out contacting identity server") - except Exception as e: - logger.warning("Error when performing a v2 3pid lookup: %s", e) - raise SynapseError( - 500, "Unknown error occurred during identity server lookup" - ) - - # Check for a mapping from what we looked up to an MXID - if "mappings" not in lookup_results or not isinstance( - lookup_results["mappings"], dict - ): - logger.warning("No results from 3pid lookup") - return None - - # Return the MXID if it's available, or None otherwise - mxid = lookup_results["mappings"].get(lookup_value) - return mxid - - async def ask_id_server_for_third_party_invite( - self, - requester: Requester, - id_server: str, - medium: str, - address: str, - room_id: str, - inviter_user_id: str, - room_alias: str, - room_avatar_url: str, - room_join_rules: str, - room_name: str, - room_type: Optional[str], - inviter_display_name: str, - inviter_avatar_url: str, - id_access_token: str, - ) -> Tuple[str, List[Dict[str, str]], Dict[str, str], str]: - """ - Asks an identity server for a third party invite. - - Args: - requester - id_server: hostname + optional port for the identity server. - medium: The literal string "email". - address: The third party address being invited. - room_id: The ID of the room to which the user is invited. - inviter_user_id: The user ID of the inviter. - room_alias: An alias for the room, for cosmetic notifications. - room_avatar_url: The URL of the room's avatar, for cosmetic - notifications. 
- room_join_rules: The join rules of the email (e.g. "public"). - room_name: The m.room.name of the room. - room_type: The type of the room from its m.room.create event (e.g "m.space"). - inviter_display_name: The current display name of the - inviter. - inviter_avatar_url: The URL of the inviter's avatar. - id_access_token: The access token to authenticate to the identity - server with - - Returns: - A tuple containing: - token: The token which must be signed to prove authenticity. - public_keys ([{"public_key": str, "key_validity_url": str}]): - public_key is a base64-encoded ed25519 public key. - fallback_public_key: One element from public_keys. - display_name: A user-friendly name to represent the invited user. - """ - invite_config = { - "medium": medium, - "address": address, - "room_id": room_id, - "room_alias": room_alias, - "room_avatar_url": room_avatar_url, - "room_join_rules": room_join_rules, - "room_name": room_name, - "sender": inviter_user_id, - "sender_display_name": inviter_display_name, - "sender_avatar_url": inviter_avatar_url, - } - - if room_type is not None: - invite_config["room_type"] = room_type - - # If a custom web client location is available, include it in the request. 
- if self._web_client_location: - invite_config["org.matrix.web_client_location"] = self._web_client_location - - # Add the identity service access token to the JSON body and use the v2 - # Identity Service endpoints - data = None - - key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % ( - id_server_scheme, - id_server, - ) - - url = "%s%s/_matrix/identity/v2/store-invite" % (id_server_scheme, id_server) - try: - data = await self._http_client.post_json_get_json( - url, - invite_config, - {"Authorization": create_id_access_token_header(id_access_token)}, - ) - except RequestTimedOutError: - raise SynapseError(500, "Timed out contacting identity server") - - token = data["token"] - public_keys = data.get("public_keys", []) - if "public_key" in data: - fallback_public_key = { - "public_key": data["public_key"], - "key_validity_url": key_validity_url, - } - else: - fallback_public_key = public_keys[0] - - if not public_keys: - public_keys.append(fallback_public_key) - display_name = data["display_name"] - return token, public_keys, fallback_public_key, display_name - - -def create_id_access_token_header(id_access_token: str) -> List[str]: - """Create an Authorization header for passing to SimpleHttpClient as the header value - of an HTTP request. - - Args: - id_access_token: An identity server access token. - - Returns: - The ascii-encoded bearer token encased in a list. - """ - # Prefix with Bearer - bearer_token = "Bearer %s" % id_access_token - - # Encode headers to standard ascii - bearer_token.encode("ascii") - - # Return as a list as that's how SimpleHttpClient takes header values - return [bearer_token] - - -class LookupAlgorithm: - """ - Supported hashing algorithms when performing a 3PID lookup. - - SHA256 - Hashing an (address, medium, pepper) combo with sha256, then url-safe base64 - encoding - NONE - Not performing any hashing. 
Simply sending an (address, medium) combo in plaintext - """ - - SHA256 = "sha256" - NONE = "none" diff --git a/synapse/handlers/jwt.py b/synapse/handlers/jwt.py
index 5fa7a305ad..400f3a59aa 100644 --- a/synapse/handlers/jwt.py +++ b/synapse/handlers/jwt.py
@@ -18,7 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, Tuple from authlib.jose import JsonWebToken, JWTClaims from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError @@ -36,11 +36,12 @@ class JwtHandler: self.jwt_secret = hs.config.jwt.jwt_secret self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim + self.jwt_display_name_claim = hs.config.jwt.jwt_display_name_claim self.jwt_algorithm = hs.config.jwt.jwt_algorithm self.jwt_issuer = hs.config.jwt.jwt_issuer self.jwt_audiences = hs.config.jwt.jwt_audiences - def validate_login(self, login_submission: JsonDict) -> str: + def validate_login(self, login_submission: JsonDict) -> Tuple[str, Optional[str]]: """ Authenticates the user for the /login API @@ -49,7 +50,8 @@ class JwtHandler: (including 'type' and other relevant fields) Returns: - The user ID that is logging in. + A tuple of (user_id, display_name) of the user that is logging in. + If the JWT does not contain a display name, the second element of the tuple will be None. Raises: LoginError if there was an authentication problem. @@ -109,4 +111,10 @@ class JwtHandler: if user is None: raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN) - return UserID(user, self.hs.hostname).to_string() + default_display_name = None + if self.jwt_display_name_claim: + display_name_claim = claims.get(self.jwt_display_name_claim) + if display_name_claim is not None: + default_display_name = display_name_claim + + return UserID(user, self.hs.hostname).to_string(), default_display_name diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 5aa48230ec..cb6de02309 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py
@@ -143,9 +143,9 @@ class MessageHandler: elif membership == Membership.LEAVE: key = (event_type, state_key) # If the membership is not JOIN, then the event ID should exist. - assert ( - membership_event_id is not None - ), "check_user_in_room_or_world_readable returned invalid data" + assert membership_event_id is not None, ( + "check_user_in_room_or_world_readable returned invalid data" + ) room_state = await self._state_storage_controller.get_state_for_events( [membership_event_id], StateFilter.from_types([key]) ) @@ -196,7 +196,9 @@ class MessageHandler: AuthError (403) if the user doesn't have permission to view members of this room. """ - state_filter = state_filter or StateFilter.all() + if state_filter is None: + state_filter = StateFilter.all() + user_id = requester.user.to_string() if at_token: @@ -240,9 +242,9 @@ class MessageHandler: room_state = await self.store.get_events(state_ids.values()) elif membership == Membership.LEAVE: # If the membership is not JOIN, then the event ID should exist. 
- assert ( - membership_event_id is not None - ), "check_user_in_room_or_world_readable returned invalid data" + assert membership_event_id is not None, ( + "check_user_in_room_or_world_readable returned invalid data" + ) room_state_events = ( await self._state_storage_controller.get_state_for_events( [membership_event_id], state_filter=state_filter @@ -493,6 +495,7 @@ class EventCreationHandler: self._instance_name = hs.get_instance_name() self._notifier = hs.get_notifier() self._worker_lock_handler = hs.get_worker_locks_handler() + self._policy_handler = hs.get_room_policy_handler() self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state @@ -642,11 +645,33 @@ class EventCreationHandler: """ await self.auth_blocking.check_auth_blocking(requester=requester) - if event_dict["type"] == EventTypes.Message: - requester_suspended = await self.store.get_user_suspended_status( - requester.user.to_string() - ) - if requester_suspended: + requester_suspended = await self.store.get_user_suspended_status( + requester.user.to_string() + ) + if requester_suspended: + # We want to allow suspended users to perform "corrective" actions + # asked of them by server admins, such as redact their messages and + # leave rooms. 
+ if event_dict["type"] in ["m.room.redaction", "m.room.member"]: + if event_dict["type"] == "m.room.redaction": + event = await self.store.get_event( + event_dict["content"]["redacts"], allow_none=True + ) + if event: + if event.sender != requester.user.to_string(): + raise SynapseError( + 403, + "You can only redact your own events while account is suspended.", + Codes.USER_ACCOUNT_SUSPENDED, + ) + if event_dict["type"] == "m.room.member": + if event_dict["content"]["membership"] != "leave": + raise SynapseError( + 403, + "Changing membership while account is suspended is not allowed.", + Codes.USER_ACCOUNT_SUSPENDED, + ) + else: raise SynapseError( 403, "Sending messages while account is suspended is not allowed.", @@ -1084,6 +1109,18 @@ class EventCreationHandler: event.sender, ) + policy_allowed = await self._policy_handler.is_event_allowed(event) + if not policy_allowed: + logger.warning( + "Event not allowed by policy server, rejecting %s", + event.event_id, + ) + raise SynapseError( + 403, + "This message has been rejected as probable spam", + Codes.FORBIDDEN, + ) + spam_check_result = ( await self._spam_checker_module_callbacks.check_event_for_spam( event @@ -1095,7 +1132,7 @@ class EventCreationHandler: [code, dict] = spam_check_result raise SynapseError( 403, - "This message had been rejected as probable spam", + "This message has been rejected as probable spam", code, dict, ) @@ -1225,10 +1262,9 @@ class EventCreationHandler: ) if prev_event_ids is not None: - assert ( - len(prev_event_ids) <= 10 - ), "Attempting to create an event with %i prev_events" % ( - len(prev_event_ids), + assert len(prev_event_ids) <= 10, ( + "Attempting to create an event with %i prev_events" + % (len(prev_event_ids),) ) else: prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id) @@ -1243,12 +1279,14 @@ class EventCreationHandler: # Allow an event to have empty list of prev_event_ids # only if it has auth_event_ids. 
or auth_event_ids - ), "Attempting to create a non-m.room.create event with no prev_events or auth_event_ids" + ), ( + "Attempting to create a non-m.room.create event with no prev_events or auth_event_ids" + ) else: # we now ought to have some prev_events (unless it's a create event). - assert ( - builder.type == EventTypes.Create or prev_event_ids - ), "Attempting to create a non-m.room.create event with no prev_events" + assert builder.type == EventTypes.Create or prev_event_ids, ( + "Attempting to create a non-m.room.create event with no prev_events" + ) if for_batch: assert prev_event_ids is not None @@ -1439,6 +1477,12 @@ class EventCreationHandler: ) return prev_event + if not event.is_state() and event.type in [ + EventTypes.Message, + EventTypes.Encrypted, + ]: + await self.store.set_room_participation(event.user_id, event.room_id) + if event.internal_metadata.is_out_of_band_membership(): # the only sort of out-of-band-membership events we expect to see here are # invite rejections and rescinded knocks that we have generated ourselves. diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index 22b59829fa..4b85282c1e 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py
@@ -31,6 +31,7 @@ from typing import ( List, Optional, Type, + TypedDict, TypeVar, Union, ) @@ -52,7 +53,6 @@ from pymacaroons.exceptions import ( MacaroonInitException, MacaroonInvalidSignatureException, ) -from typing_extensions import TypedDict from twisted.web.client import readBody from twisted.web.http_headers import Headers @@ -382,7 +382,12 @@ class OidcProvider: self._macaroon_generaton = macaroon_generator self._config = provider - self._callback_url: str = hs.config.oidc.oidc_callback_url + + self._callback_url: str + if provider.redirect_uri is not None: + self._callback_url = provider.redirect_uri + else: + self._callback_url = hs.config.oidc.oidc_callback_url # Calculate the prefix for OIDC callback paths based on the public_baseurl. # We'll insert this into the Path= parameter of any session cookies we set. @@ -462,6 +467,10 @@ class OidcProvider: self._sso_handler.register_identity_provider(self) + self.passthrough_authorization_parameters = ( + provider.passthrough_authorization_parameters + ) + def _validate_metadata(self, m: OpenIDProviderMetadata) -> None: """Verifies the provider metadata. @@ -578,6 +587,24 @@ class OidcProvider: ) @property + def _uses_access_token(self) -> bool: + """Return True if the `access_token` will be used during the login process. + + This is useful to determine whether the access token + returned by the identity provider, and + any related metadata (such as the `at_hash` field in + the ID token), should be validated. + """ + # Currently, Synapse only uses the access_token to fetch user metadata + # from the userinfo endpoint. Therefore we only have a single criteria + # to check right now but this may change in the future and this function + # should be updated if more usages are introduced. + # + # For example, if we start to use the access_token given to us by the + # IdP for more things, such as accessing Resource Server APIs. 
+ return self._uses_userinfo + + @property def issuer(self) -> str: """The issuer identifying this provider.""" return self._config.issuer @@ -640,6 +667,11 @@ class OidcProvider: elif self._config.pkce_method == "never": metadata.pop("code_challenge_methods_supported", None) + if self._config.id_token_signing_alg_values_supported: + metadata["id_token_signing_alg_values_supported"] = ( + self._config.id_token_signing_alg_values_supported + ) + self._validate_metadata(metadata) return metadata @@ -943,9 +975,16 @@ class OidcProvider: "nonce": nonce, "client_id": self._client_auth.client_id, } - if "access_token" in token: + if self._uses_access_token and "access_token" in token: # If we got an `access_token`, there should be an `at_hash` claim - # in the `id_token` that we can check against. + # in the `id_token` that we can check against. Setting this + # instructs authlib to check the value of `at_hash` in the + # ID token. + # + # We only need to verify the access token if we actually make + # use of it. Which currently only happens when we need to fetch + # the user's information from the userinfo_endpoint. Thus, this + # check is also gated on self._uses_userinfo. claims_params["access_token"] = token["access_token"] claims_options = {"iss": {"values": [metadata["issuer"]]}} @@ -995,14 +1034,27 @@ class OidcProvider: when everything is done (or None for UI Auth) ui_auth_session_id: The session ID of the ongoing UI Auth (or None if this is a login). - Returns: The redirect URL to the authorization endpoint. """ state = generate_token() - nonce = generate_token() + + # Generate a nonce 32 characters long. When encoded with base64url later on, + # the nonce will be 43 characters when sent to the identity provider. + # + # While RFC7636 does not specify a minimum length for the `nonce` + # parameter, the TI-Messenger IDP_FD spec v1.7.3 does require it to be + # between 43 and 128 characters. 
This spec concerns using Matrix for + # communication in German healthcare. + # + # As increasing the length only strengthens security, we use this length + # to allow TI-Messenger deployments using Synapse to satisfy this + # external spec. + # + # See https://github.com/element-hq/synapse/pull/18109 for more context. + nonce = generate_token(length=32) code_verifier = "" if not client_redirect_url: @@ -1054,6 +1106,13 @@ class OidcProvider: ) ) + # add passthrough additional authorization parameters + passthrough_authorization_parameters = self.passthrough_authorization_parameters + for parameter in passthrough_authorization_parameters: + parameter_value = parse_string(request, parameter) + if parameter_value: + additional_authorization_parameters.update({parameter: parameter_value}) + authorization_endpoint = metadata.get("authorization_endpoint") return prepare_grant_uri( authorization_endpoint, @@ -1716,17 +1775,12 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): if display_name == "": display_name = None - emails: List[str] = [] - email = render_template_field(self._config.email_template) - if email: - emails.append(email) - picture = self._config.picture_template.render(user=userinfo).strip() return UserAttributeDict( localpart=localpart, display_name=display_name, - emails=emails, + emails=[], # 3PIDs are not supported picture=picture, confirm_localpart=self._config.confirm_localpart, ) diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 6fd7afa280..365c9cabcb 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py
@@ -507,15 +507,17 @@ class PaginationHandler: # Initially fetch the events from the database. With any luck, we can return # these without blocking on backfill (handled below). - events, next_key = ( - await self.store.paginate_room_events_by_topological_ordering( - room_id=room_id, - from_key=from_token.room_key, - to_key=to_room_key, - direction=pagin_config.direction, - limit=pagin_config.limit, - event_filter=event_filter, - ) + ( + events, + next_key, + limited, + ) = await self.store.paginate_room_events_by_topological_ordering( + room_id=room_id, + from_key=from_token.room_key, + to_key=to_room_key, + direction=pagin_config.direction, + limit=pagin_config.limit, + event_filter=event_filter, ) if pagin_config.direction == Direction.BACKWARDS: @@ -575,25 +577,31 @@ class PaginationHandler: or missing_too_many_events or not_enough_events_to_fill_response ): - did_backfill = await self.hs.get_federation_handler().maybe_backfill( + # Historical Note: There used to be a check here for if backfill was + # successful or not + await self.hs.get_federation_handler().maybe_backfill( room_id, curr_topo, limit=pagin_config.limit, ) - # If we did backfill something, refetch the events from the database to - # catch anything new that might have been added since we last fetched. - if did_backfill: - events, next_key = ( - await self.store.paginate_room_events_by_topological_ordering( - room_id=room_id, - from_key=from_token.room_key, - to_key=to_room_key, - direction=pagin_config.direction, - limit=pagin_config.limit, - event_filter=event_filter, - ) - ) + # Regardless if we backfilled or not, another worker or even a + # simultaneous request may have backfilled for us while we were held + # behind the linearizer. 
This should not have too much additional + # database load as it will only be triggered if a backfill *might* have + # been needed + ( + events, + next_key, + limited, + ) = await self.store.paginate_room_events_by_topological_ordering( + room_id=room_id, + from_key=from_token.room_key, + to_key=to_room_key, + direction=pagin_config.direction, + limit=pagin_config.limit, + event_filter=event_filter, + ) else: # Otherwise, we can backfill in the background for eventual # consistency's sake but we don't need to block the client waiting @@ -608,6 +616,15 @@ class PaginationHandler: next_token = from_token.copy_and_replace(StreamKeyType.ROOM, next_key) + # We might have hit some internal filtering first, for example rejected + # events. Ensure we return a pagination token then. + if not events and limited: + return { + "chunk": [], + "start": await from_token.to_string(self.store), + "end": await next_token.to_string(self.store), + } + # if no events are returned from pagination, that implies # we have reached the end of the available events. # In that case we do not return end, to tell the client diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 37ee625f71..390cafa8f6 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py
@@ -71,6 +71,7 @@ user state; this device follows the normal timeout logic (see above) and will automatically be replaced with any information from currently available devices. """ + import abc import contextlib import itertools @@ -493,9 +494,9 @@ class WorkerPresenceHandler(BasePresenceHandler): # The number of ongoing syncs on this process, by (user ID, device ID). # Empty if _presence_enabled is false. - self._user_device_to_num_current_syncs: Dict[Tuple[str, Optional[str]], int] = ( - {} - ) + self._user_device_to_num_current_syncs: Dict[ + Tuple[str, Optional[str]], int + ] = {} self.notifier = hs.get_notifier() self.instance_id = hs.get_instance_id() @@ -818,9 +819,9 @@ class PresenceHandler(BasePresenceHandler): # Keeps track of the number of *ongoing* syncs on this process. While # this is non zero a user will never go offline. - self._user_device_to_num_current_syncs: Dict[Tuple[str, Optional[str]], int] = ( - {} - ) + self._user_device_to_num_current_syncs: Dict[ + Tuple[str, Optional[str]], int + ] = {} # Keeps track of the number of *ongoing* syncs on other processes. # diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 6663d4b271..cdc388b4ab 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py
@@ -22,6 +22,7 @@ import logging import random from typing import TYPE_CHECKING, List, Optional, Union +from synapse.api.constants import ProfileFields from synapse.api.errors import ( AuthError, Codes, @@ -31,7 +32,7 @@ from synapse.api.errors import ( SynapseError, ) from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia -from synapse.types import JsonDict, Requester, UserID, create_requester +from synapse.types import JsonDict, JsonValue, Requester, UserID, create_requester from synapse.util.caches.descriptors import cached from synapse.util.stringutils import parse_and_validate_mxc_uri @@ -42,6 +43,8 @@ logger = logging.getLogger(__name__) MAX_DISPLAYNAME_LEN = 256 MAX_AVATAR_URL_LEN = 1000 +# Field name length is specced at 255 bytes. +MAX_CUSTOM_FIELD_LEN = 255 class ProfileHandler: @@ -74,17 +77,42 @@ class ProfileHandler: self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules async def get_profile(self, user_id: str, ignore_backoff: bool = True) -> JsonDict: + """ + Get a user's profile as a JSON dictionary. + + Args: + user_id: The user to fetch the profile of. + ignore_backoff: True to ignore backoff when fetching over federation. + + Returns: + A JSON dictionary. For local queries this will include the displayname and avatar_url + fields, if set. For remote queries it may contain arbitrary information. 
+ """ target_user = UserID.from_string(user_id) if self.hs.is_mine(target_user): profileinfo = await self.store.get_profileinfo(target_user) - if profileinfo.display_name is None and profileinfo.avatar_url is None: + extra_fields = {} + if self.hs.config.experimental.msc4133_enabled: + extra_fields = await self.store.get_profile_fields(target_user) + + if ( + profileinfo.display_name is None + and profileinfo.avatar_url is None + and not extra_fields + ): raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) - return { - "displayname": profileinfo.display_name, - "avatar_url": profileinfo.avatar_url, - } + # Do not include display name or avatar if unset. + ret = {} + if profileinfo.display_name is not None: + ret[ProfileFields.DISPLAYNAME] = profileinfo.display_name + if profileinfo.avatar_url is not None: + ret[ProfileFields.AVATAR_URL] = profileinfo.avatar_url + if extra_fields: + ret.update(extra_fields) + + return ret else: try: result = await self.federation.make_query( @@ -107,6 +135,15 @@ class ProfileHandler: raise e.to_synapse_error() async def get_displayname(self, target_user: UserID) -> Optional[str]: + """ + Fetch a user's display name from their profile. + + Args: + target_user: The user to fetch the display name of. + + Returns: + The user's display name or None if unset. + """ if self.hs.is_mine(target_user): try: displayname = await self.store.get_profile_displayname(target_user) @@ -203,6 +240,15 @@ class ProfileHandler: await self._update_join_states(requester, target_user) async def get_avatar_url(self, target_user: UserID) -> Optional[str]: + """ + Fetch a user's avatar URL from their profile. + + Args: + target_user: The user to fetch the avatar URL of. + + Returns: + The user's avatar URL or None if unset. 
+ """ if self.hs.is_mine(target_user): try: avatar_url = await self.store.get_profile_avatar_url(target_user) @@ -322,9 +368,9 @@ class ProfileHandler: server_name = host if self._is_mine_server_name(server_name): - media_info: Optional[Union[LocalMedia, RemoteMedia]] = ( - await self.store.get_local_media(media_id) - ) + media_info: Optional[ + Union[LocalMedia, RemoteMedia] + ] = await self.store.get_local_media(media_id) else: media_info = await self.store.get_cached_remote_media(server_name, media_id) @@ -370,6 +416,110 @@ class ProfileHandler: return True + async def get_profile_field( + self, target_user: UserID, field_name: str + ) -> JsonValue: + """ + Fetch a user's profile from the database for local users and over federation + for remote users. + + Args: + target_user: The user ID to fetch the profile for. + field_name: The field to fetch the profile for. + + Returns: + The value for the profile field or None if the field does not exist. + """ + if self.hs.is_mine(target_user): + try: + field_value = await self.store.get_profile_field( + target_user, field_name + ) + except StoreError as e: + if e.code == 404: + raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) + raise + + return field_value + else: + try: + result = await self.federation.make_query( + destination=target_user.domain, + query_type="profile", + args={"user_id": target_user.to_string(), "field": field_name}, + ignore_backoff=True, + ) + except RequestSendFailed as e: + raise SynapseError(502, "Failed to fetch profile") from e + except HttpResponseException as e: + raise e.to_synapse_error() + + return result.get(field_name) + + async def set_profile_field( + self, + target_user: UserID, + requester: Requester, + field_name: str, + new_value: JsonValue, + by_admin: bool = False, + deactivation: bool = False, + ) -> None: + """Set a new profile field for a user. + + Args: + target_user: the user whose profile is to be changed. 
+ requester: The user attempting to make this change. + field_name: The name of the profile field to update. + new_value: The new field value for this user. + by_admin: Whether this change was made by an administrator. + deactivation: Whether this change was made while deactivating the user. + """ + if not self.hs.is_mine(target_user): + raise SynapseError(400, "User is not hosted on this homeserver") + + if not by_admin and target_user != requester.user: + raise AuthError(403, "Cannot set another user's profile") + + await self.store.set_profile_field(target_user, field_name, new_value) + + # Custom fields do not propagate into the user directory *or* rooms. + profile = await self.store.get_profileinfo(target_user) + await self._third_party_rules.on_profile_update( + target_user.to_string(), profile, by_admin, deactivation + ) + + async def delete_profile_field( + self, + target_user: UserID, + requester: Requester, + field_name: str, + by_admin: bool = False, + deactivation: bool = False, + ) -> None: + """Delete a field from a user's profile. + + Args: + target_user: the user whose profile is to be changed. + requester: The user attempting to make this change. + field_name: The name of the profile field to remove. + by_admin: Whether this change was made by an administrator. + deactivation: Whether this change was made while deactivating the user. + """ + if not self.hs.is_mine(target_user): + raise SynapseError(400, "User is not hosted on this homeserver") + + if not by_admin and target_user != requester.user: + raise AuthError(400, "Cannot set another user's profile") + + await self.store.delete_profile_field(target_user, field_name) + + # Custom fields do not propagate into the user directory *or* rooms. 
+ profile = await self.store.get_profileinfo(target_user) + await self._third_party_rules.on_profile_update( + target_user.to_string(), profile, by_admin, deactivation + ) + async def on_profile_query(self, args: JsonDict) -> JsonDict: """Handles federation profile query requests.""" @@ -386,13 +536,24 @@ class ProfileHandler: just_field = args.get("field", None) - response = {} + response: JsonDict = {} try: - if just_field is None or just_field == "displayname": + if just_field is None or just_field == ProfileFields.DISPLAYNAME: response["displayname"] = await self.store.get_profile_displayname(user) - if just_field is None or just_field == "avatar_url": + if just_field is None or just_field == ProfileFields.AVATAR_URL: response["avatar_url"] = await self.store.get_profile_avatar_url(user) + + if self.hs.config.experimental.msc4133_enabled: + if just_field is None: + response.update(await self.store.get_profile_fields(user)) + elif just_field not in ( + ProfileFields.DISPLAYNAME, + ProfileFields.AVATAR_URL, + ): + response[just_field] = await self.store.get_profile_field( + user, just_field + ) except StoreError as e: if e.code == 404: raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) @@ -403,6 +564,12 @@ class ProfileHandler: async def _update_join_states( self, requester: Requester, target_user: UserID ) -> None: + """ + Update the membership events of each room the user is joined to with the + new profile information. + + Note that this stomps over any custom display name or avatar URL in member events. + """ if not self.hs.is_mine(target_user): return diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index c200e29569..8dd687c455 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py
@@ -23,10 +23,9 @@ """Contains functions for registering clients.""" import logging -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, TypedDict from prometheus_client import Counter -from typing_extensions import TypedDict from synapse import types from synapse.api.constants import ( @@ -44,7 +43,6 @@ from synapse.api.errors import ( SynapseError, ) from synapse.appservice import ApplicationService -from synapse.config.server import is_threepid_reserved from synapse.handlers.device import DeviceHandler from synapse.http.servlet import assert_params_in_dict from synapse.replication.http.login import RegisterDeviceReplicationServlet @@ -109,13 +107,13 @@ class RegistrationHandler: self._auth_handler = hs.get_auth_handler() self.profile_handler = hs.get_profile_handler() self.user_directory_handler = hs.get_user_directory_handler() - self.identity_handler = self.hs.get_identity_handler() self.ratelimiter = hs.get_registration_ratelimiter() self.macaroon_gen = hs.get_macaroon_generator() self._account_validity_handler = hs.get_account_validity_handler() self._user_consent_version = self.hs.config.consent.user_consent_version self._server_notices_mxid = hs.config.servernotices.server_notices_mxid self._server_name = hs.hostname + self._user_types_config = hs.config.user_types self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker @@ -160,7 +158,10 @@ class RegistrationHandler: if not localpart: raise SynapseError(400, "User ID cannot be empty", Codes.INVALID_USERNAME) - if localpart[0] == "_": + if ( + localpart[0] == "_" + and not self.hs.config.registration.allow_underscore_prefixed_localpart + ): raise SynapseError( 400, "User ID may not begin with _", Codes.INVALID_USERNAME ) @@ -304,6 +305,9 @@ class RegistrationHandler: elif default_display_name is None: default_display_name = localpart + if user_type is None: + user_type = 
self._user_types_config.default_user_type + await self.register_with_store( user_id=user_id, password_hash=password_hash, @@ -380,19 +384,6 @@ class RegistrationHandler: user_id, ) - # Bind any specified emails to this account - current_time = self.hs.get_clock().time_msec() - for email in bind_emails: - # generate threepid dict - threepid_dict = { - "medium": "email", - "address": email, - "validated_at": current_time, - } - - # Bind email to new account - await self._register_email_threepid(user_id, threepid_dict, None) - return user_id async def _create_and_join_rooms(self, user_id: str) -> None: @@ -630,7 +621,9 @@ class RegistrationHandler: """ await self._auto_join_rooms(user_id) - async def appservice_register(self, user_localpart: str, as_token: str) -> str: + async def appservice_register( + self, user_localpart: str, as_token: str + ) -> Tuple[str, ApplicationService]: user = UserID(user_localpart, self.hs.hostname) user_id = user.to_string() service = self.store.get_app_service_by_token(as_token) @@ -653,7 +646,7 @@ class RegistrationHandler: appservice_id=service_id, create_profile_with_displayname=user.localpart, ) - return user_id + return (user_id, service) def check_user_id_not_appservice_exclusive( self, user_id: str, allowed_appservice: Optional[ApplicationService] = None @@ -941,21 +934,6 @@ class RegistrationHandler: ) return - if auth_result and LoginType.EMAIL_IDENTITY in auth_result: - threepid = auth_result[LoginType.EMAIL_IDENTITY] - # Necessary due to auth checks prior to the threepid being - # written to the db - if is_threepid_reserved( - self.hs.config.server.mau_limits_reserved_threepids, threepid - ): - await self.store.upsert_monthly_active_user(user_id) - - await self._register_email_threepid(user_id, threepid, access_token) - - if auth_result and LoginType.MSISDN in auth_result: - threepid = auth_result[LoginType.MSISDN] - await self._register_msisdn_threepid(user_id, threepid) - if auth_result and LoginType.TERMS in auth_result: # 
The terms type should only exist if consent is enabled. assert self._user_consent_version is not None @@ -971,86 +949,3 @@ class RegistrationHandler: logger.info("%s has consented to the privacy policy", user_id) await self.store.user_set_consent_version(user_id, consent_version) await self.post_consent_actions(user_id) - - async def _register_email_threepid( - self, user_id: str, threepid: dict, token: Optional[str] - ) -> None: - """Add an email address as a 3pid identifier - - Also adds an email pusher for the email address, if configured in the - HS config - - Must be called on master. - - Args: - user_id: id of user - threepid: m.login.email.identity auth response - token: access_token for the user, or None if not logged in. - """ - reqd = ("medium", "address", "validated_at") - if any(x not in threepid for x in reqd): - # This will only happen if the ID server returns a malformed response - logger.info("Can't add incomplete 3pid") - return - - await self._auth_handler.add_threepid( - user_id, - threepid["medium"], - threepid["address"], - threepid["validated_at"], - ) - - # And we add an email pusher for them by default, but only - # if email notifications are enabled (so people don't start - # getting mail spam where they weren't before if email - # notifs are set up on a homeserver) - if ( - self.hs.config.email.email_enable_notifs - and self.hs.config.email.email_notif_for_new_users - and token - ): - # Pull the ID of the access token back out of the db - # It would really make more sense for this to be passed - # up when the access token is saved, but that's quite an - # invasive change I'd rather do separately. - user_tuple = await self.store.get_user_by_access_token(token) - # The token better still exist. 
- assert user_tuple - device_id = user_tuple.device_id - - await self.pusher_pool.add_or_update_pusher( - user_id=user_id, - device_id=device_id, - kind="email", - app_id="m.email", - app_display_name="Email Notifications", - device_display_name=threepid["address"], - pushkey=threepid["address"], - lang=None, - data={}, - ) - - async def _register_msisdn_threepid(self, user_id: str, threepid: dict) -> None: - """Add a phone number as a 3pid identifier - - Must be called on master. - - Args: - user_id: id of user - threepid: m.login.msisdn auth response - """ - try: - assert_params_in_dict(threepid, ["medium", "address", "validated_at"]) - except SynapseError as ex: - if ex.errcode == Codes.MISSING_PARAM: - # This will only happen if the ID server returns a malformed response - logger.info("Can't add incomplete 3pid") - return None - raise - - await self._auth_handler.add_threepid( - user_id, - threepid["medium"], - threepid["address"], - threepid["validated_at"], - ) diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index efe31e81f9..b1158ee77d 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py
@@ -188,13 +188,13 @@ class RelationsHandler: if include_original_event: # Do not bundle aggregations when retrieving the original event because # we want the content before relations are applied to it. - return_value["original_event"] = ( - await self._event_serializer.serialize_event( - event, - now, - bundle_aggregations=None, - config=serialize_options, - ) + return_value[ + "original_event" + ] = await self._event_serializer.serialize_event( + event, + now, + bundle_aggregations=None, + config=serialize_options, ) if next_token: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 2c6e672ede..1ccb6f7171 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py
@@ -20,6 +20,7 @@ # """Contains functions for performing actions on rooms.""" + import itertools import logging import math @@ -467,17 +468,6 @@ class RoomCreationHandler: """ user_id = requester.user.to_string() - spam_check = await self._spam_checker_module_callbacks.user_may_create_room( - user_id - ) - if spam_check != self._spam_checker_module_callbacks.NOT_SPAM: - raise SynapseError( - 403, - "You are not permitted to create rooms", - errcode=spam_check[0], - additional_fields=spam_check[1], - ) - creation_content: JsonDict = { "room_version": new_room_version.identifier, "predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id}, @@ -584,6 +574,24 @@ class RoomCreationHandler: if current_power_level_int < needed_power_level: user_power_levels[user_id] = needed_power_level + # We construct what the body of a call to /createRoom would look like for passing + # to the spam checker. We don't include a preset here, as we expect the + # initial state to contain everything we need. 
+ spam_check = await self._spam_checker_module_callbacks.user_may_create_room( + user_id, + { + "creation_content": creation_content, + "initial_state": list(initial_state.items()), + }, + ) + if spam_check != self._spam_checker_module_callbacks.NOT_SPAM: + raise SynapseError( + 403, + "You are not permitted to create rooms", + errcode=spam_check[0], + additional_fields=spam_check[1], + ) + await self._send_events_for_new_room( requester, new_room_id, @@ -785,7 +793,7 @@ class RoomCreationHandler: if not is_requester_admin: spam_check = await self._spam_checker_module_callbacks.user_may_create_room( - user_id + user_id, config ) if spam_check != self._spam_checker_module_callbacks.NOT_SPAM: raise SynapseError( @@ -900,11 +908,9 @@ class RoomCreationHandler: ) # Check whether this visibility value is blocked by a third party module - allowed_by_third_party_rules = ( - await ( - self._third_party_event_rules.check_visibility_can_be_modified( - room_id, visibility - ) + allowed_by_third_party_rules = await ( + self._third_party_event_rules.check_visibility_can_be_modified( + room_id, visibility ) ) if not allowed_by_third_party_rules: @@ -1754,7 +1760,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]): ) events = list(room_events) - events.extend(e for evs, _ in room_to_events.values() for e in evs) + events.extend(e for evs, _, _ in room_to_events.values() for e in evs) # We know stream_ordering must be not None here, as its been # persisted, but mypy doesn't know that @@ -1807,7 +1813,7 @@ class RoomShutdownHandler: ] = None, ) -> Optional[ShutdownRoomResponse]: """ - Shuts down a room. Moves all local users and room aliases automatically + Shuts down a room. Moves all joined local users and room aliases automatically to a new room if `new_room_user_id` is set. Otherwise local users only leave the room without any information. 
@@ -1950,16 +1956,17 @@ class RoomShutdownHandler: # Join users to new room if new_room_user_id: - assert new_room_id is not None - await self.room_member_handler.update_membership( - requester=target_requester, - target=target_requester.user, - room_id=new_room_id, - action=Membership.JOIN, - content={}, - ratelimit=False, - require_consent=False, - ) + if membership == Membership.JOIN: + assert new_room_id is not None + await self.room_member_handler.update_membership( + requester=target_requester, + target=target_requester.user, + room_id=new_room_id, + action=Membership.JOIN, + content={}, + ratelimit=False, + require_consent=False, + ) result["kicked_users"].append(user_id) if update_result_fct: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 51b9772329..a3a7326d94 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py
@@ -53,6 +53,7 @@ from synapse.metrics import event_processing_positions from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.push import ReplicationCopyPusherRestServlet from synapse.storage.databases.main.state_deltas import StateDelta +from synapse.storage.invite_rule import InviteRule from synapse.types import ( JsonDict, Requester, @@ -98,7 +99,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.federation_handler = hs.get_federation_handler() self.directory_handler = hs.get_directory_handler() - self.identity_handler = hs.get_identity_handler() self.registration_handler = hs.get_registration_handler() self.profile_handler = hs.get_profile_handler() self.event_creation_handler = hs.get_event_creation_handler() @@ -122,7 +122,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): hs.get_module_api_callbacks().third_party_event_rules ) self._server_notices_mxid = self.config.servernotices.server_notices_mxid - self._enable_lookup = hs.config.registration.enable_3pid_lookup self.allow_per_room_profiles = self.config.server.allow_per_room_profiles self._join_rate_limiter_local = Ratelimiter( @@ -158,6 +157,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): store=self.store, clock=self.clock, cfg=hs.config.ratelimiting.rc_invites_per_room, + ratelimit_callbacks=hs.get_module_api_callbacks().ratelimit, ) # Ratelimiter for invites, keyed by recipient (across all rooms, all @@ -166,6 +166,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): store=self.store, clock=self.clock, cfg=hs.config.ratelimiting.rc_invites_per_user, + ratelimit_callbacks=hs.get_module_api_callbacks().ratelimit, ) # Ratelimiter for invites, keyed by issuer (across all rooms, all @@ -174,6 +175,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): store=self.store, clock=self.clock, cfg=hs.config.ratelimiting.rc_invites_per_issuer, + ratelimit_callbacks=hs.get_module_api_callbacks().ratelimit, ) self._third_party_invite_limiter = 
Ratelimiter( @@ -912,6 +914,21 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): additional_fields=block_invite_result[1], ) + # check the invitee's configuration and apply rules. Admins on the server can bypass. + if not is_requester_admin: + invite_config = await self.store.get_invite_config_for_user(target_id) + rule = invite_config.get_invite_rule(requester.user.to_string()) + if rule == InviteRule.BLOCK: + logger.info( + f"Automatically rejecting invite from {target_id} due to the the invite filtering rules of {requester.user}" + ) + raise SynapseError( + 403, + "You are not permitted to invite this user.", + errcode=Codes.INVITE_BLOCKED, + ) + # InviteRule.IGNORE is handled at the sync layer. + # An empty prev_events list is allowed as long as the auth_event_ids are present if prev_event_ids is not None: return await self._local_membership_update( @@ -1190,6 +1207,26 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): origin_server_ts=origin_server_ts, ) + async def check_for_any_membership_in_room( + self, *, user_id: str, room_id: str + ) -> None: + """ + Check if the user has any membership in the room and raise error if not. + + Args: + user_id: The user to check. + room_id: The room to check. + + Raises: + AuthError if the user doesn't have any membership in the room. + """ + result = await self.store.get_local_current_membership_for_user_in_room( + user_id=user_id, room_id=room_id + ) + + if result is None or result == (None, None): + raise AuthError(403, f"User {user_id} has no membership in room {room_id}") + async def _should_perform_remote_join( self, user_id: str, @@ -1302,11 +1339,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # If this is going to be a local join, additional information must # be included in the event content in order to efficiently validate # the event. 
- content[EventContentFields.AUTHORISING_USER] = ( - await self.event_auth_handler.get_user_which_could_invite( - room_id, - state_before_join, - ) + content[ + EventContentFields.AUTHORISING_USER + ] = await self.event_auth_handler.get_user_which_could_invite( + room_id, + state_before_join, ) return False, [] @@ -1415,9 +1452,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if requester is not None: sender = UserID.from_string(event.sender) - assert ( - sender == requester.user - ), "Sender (%s) must be same as requester (%s)" % (sender, requester.user) + assert sender == requester.user, ( + "Sender (%s) must be same as requester (%s)" % (sender, requester.user) + ) assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) else: requester = types.create_requester(target_user) @@ -1572,230 +1609,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): return UserID.from_string(invite.sender) return None - async def do_3pid_invite( - self, - room_id: str, - inviter: UserID, - medium: str, - address: str, - id_server: str, - requester: Requester, - txn_id: Optional[str], - id_access_token: str, - prev_event_ids: Optional[List[str]] = None, - depth: Optional[int] = None, - ) -> Tuple[str, int]: - """Invite a 3PID to a room. - - Args: - room_id: The room to invite the 3PID to. - inviter: The user sending the invite. - medium: The 3PID's medium. - address: The 3PID's address. - id_server: The identity server to use. - requester: The user making the request. - txn_id: The transaction ID this is part of, or None if this is not - part of a transaction. - id_access_token: Identity server access token. - depth: Override the depth used to order the event in the DAG. - prev_event_ids: The event IDs to use as the prev events - Should normally be set to None, which will cause the depth to be calculated - based on the prev_events. - - Returns: - Tuple of event ID and stream ordering position - - Raises: - ShadowBanError if the requester has been shadow-banned. 
- """ - if self.config.server.block_non_admin_invites: - is_requester_admin = await self.auth.is_server_admin(requester) - if not is_requester_admin: - raise SynapseError( - 403, "Invites have been disabled on this server", Codes.FORBIDDEN - ) - - if requester.shadow_banned: - # We randomly sleep a bit just to annoy the requester. - await self.clock.sleep(random.randint(1, 10)) - raise ShadowBanError() - - # We need to rate limit *before* we send out any 3PID invites, so we - # can't just rely on the standard ratelimiting of events. - await self._third_party_invite_limiter.ratelimit(requester) - - can_invite = await self._third_party_event_rules.check_threepid_can_be_invited( - medium, address, room_id - ) - if not can_invite: - raise SynapseError( - 403, - "This third-party identifier can not be invited in this room", - Codes.FORBIDDEN, - ) - - if not self._enable_lookup: - raise SynapseError( - 403, "Looking up third-party identifiers is denied from this server" - ) - - invitee = await self.identity_handler.lookup_3pid( - id_server, medium, address, id_access_token - ) - - if invitee: - # Note that update_membership with an action of "invite" can raise - # a ShadowBanError, but this was done above already. - # We don't check the invite against the spamchecker(s) here (through - # user_may_invite) because we'll do it further down the line anyway (in - # update_membership_locked). - event_id, stream_id = await self.update_membership( - requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id - ) - else: - # Check if the spamchecker(s) allow this invite to go through. 
- spam_check = ( - await self._spam_checker_module_callbacks.user_may_send_3pid_invite( - inviter_userid=requester.user.to_string(), - medium=medium, - address=address, - room_id=room_id, - ) - ) - if spam_check != self._spam_checker_module_callbacks.NOT_SPAM: - raise SynapseError( - 403, - "Cannot send threepid invite", - errcode=spam_check[0], - additional_fields=spam_check[1], - ) - - event, stream_id = await self._make_and_store_3pid_invite( - requester, - id_server, - medium, - address, - room_id, - inviter, - txn_id=txn_id, - id_access_token=id_access_token, - prev_event_ids=prev_event_ids, - depth=depth, - ) - event_id = event.event_id - - return event_id, stream_id - - async def _make_and_store_3pid_invite( - self, - requester: Requester, - id_server: str, - medium: str, - address: str, - room_id: str, - user: UserID, - txn_id: Optional[str], - id_access_token: str, - prev_event_ids: Optional[List[str]] = None, - depth: Optional[int] = None, - ) -> Tuple[EventBase, int]: - room_state = await self._storage_controllers.state.get_current_state( - room_id, - StateFilter.from_types( - [ - (EventTypes.Member, user.to_string()), - (EventTypes.CanonicalAlias, ""), - (EventTypes.Name, ""), - (EventTypes.Create, ""), - (EventTypes.JoinRules, ""), - (EventTypes.RoomAvatar, ""), - ] - ), - ) - - inviter_display_name = "" - inviter_avatar_url = "" - member_event = room_state.get((EventTypes.Member, user.to_string())) - if member_event: - inviter_display_name = member_event.content.get("displayname", "") - inviter_avatar_url = member_event.content.get("avatar_url", "") - - # if user has no display name, default to their MXID - if not inviter_display_name: - inviter_display_name = user.to_string() - - canonical_room_alias = "" - canonical_alias_event = room_state.get((EventTypes.CanonicalAlias, "")) - if canonical_alias_event: - canonical_room_alias = canonical_alias_event.content.get("alias", "") - - room_name = "" - room_name_event = room_state.get((EventTypes.Name, 
"")) - if room_name_event: - room_name = room_name_event.content.get("name", "") - - room_type = None - room_create_event = room_state.get((EventTypes.Create, "")) - if room_create_event: - room_type = room_create_event.content.get(EventContentFields.ROOM_TYPE) - - room_join_rules = "" - join_rules_event = room_state.get((EventTypes.JoinRules, "")) - if join_rules_event: - room_join_rules = join_rules_event.content.get("join_rule", "") - - room_avatar_url = "" - room_avatar_event = room_state.get((EventTypes.RoomAvatar, "")) - if room_avatar_event: - room_avatar_url = room_avatar_event.content.get("url", "") - - ( - token, - public_keys, - fallback_public_key, - display_name, - ) = await self.identity_handler.ask_id_server_for_third_party_invite( - requester=requester, - id_server=id_server, - medium=medium, - address=address, - room_id=room_id, - inviter_user_id=user.to_string(), - room_alias=canonical_room_alias, - room_avatar_url=room_avatar_url, - room_join_rules=room_join_rules, - room_name=room_name, - room_type=room_type, - inviter_display_name=inviter_display_name, - inviter_avatar_url=inviter_avatar_url, - id_access_token=id_access_token, - ) - - ( - event, - stream_id, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - requester, - { - "type": EventTypes.ThirdPartyInvite, - "content": { - "display_name": display_name, - "public_keys": public_keys, - # For backwards compatibility: - "key_validity_url": fallback_public_key["key_validity_url"], - "public_key": fallback_public_key["public_key"], - }, - "room_id": room_id, - "sender": user.to_string(), - "state_key": token, - }, - ratelimit=False, - txn_id=txn_id, - prev_event_ids=prev_event_ids, - depth=depth, - ) - return event, stream_id - async def _is_host_in_room(self, partial_current_state_ids: StateMap[str]) -> bool: """Returns whether the homeserver is in the room based on its current state. 
diff --git a/synapse/handlers/room_policy.py b/synapse/handlers/room_policy.py new file mode 100644
index 0000000000..3a83c4d6ec --- /dev/null +++ b/synapse/handlers/room_policy.py
@@ -0,0 +1,96 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright 2016-2021 The Matrix.org Foundation C.I.C. +# Copyright (C) 2023 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# +# + +import logging +from typing import TYPE_CHECKING + +from synapse.events import EventBase +from synapse.types.handlers.policy_server import RECOMMENDATION_OK +from synapse.util.stringutils import parse_and_validate_server_name + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class RoomPolicyHandler: + def __init__(self, hs: "HomeServer"): + self._hs = hs + self._store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() + self._event_auth_handler = hs.get_event_auth_handler() + self._federation_client = hs.get_federation_client() + + async def is_event_allowed(self, event: EventBase) -> bool: + """Check if the given event is allowed in the room by the policy server. + + Note: This will *always* return True if the room's policy server is Synapse + itself. This is because Synapse can't be a policy server (currently). + + If no policy server is configured in the room, this returns True. Similarly, if + the policy server is invalid in any way (not joined, not a server, etc), this + returns True. + + If a valid and contactable policy server is configured in the room, this returns + True if that server suggests the event is not spammy, and False otherwise. + + Args: + event: The event to check. This should be a fully-formed PDU. + + Returns: + bool: True if the event is allowed in the room, False otherwise. 
+ """ + policy_event = await self._storage_controllers.state.get_current_state_event( + event.room_id, "org.matrix.msc4284.policy", "" + ) + if not policy_event: + return True # no policy server == default allow + + policy_server = policy_event.content.get("via", "") + if policy_server is None or not isinstance(policy_server, str): + return True # no policy server == default allow + + if policy_server == self._hs.hostname: + return True # Synapse itself can't be a policy server (currently) + + try: + parse_and_validate_server_name(policy_server) + except ValueError: + return True # invalid policy server == default allow + + is_in_room = await self._event_auth_handler.is_host_in_room( + event.room_id, policy_server + ) + if not is_in_room: + return True # policy server not in room == default allow + + # At this point, the server appears valid and is in the room, so ask it to check + # the event. + recommendation = await self._federation_client.get_pdu_policy_recommendation( + policy_server, event + ) + if recommendation != RECOMMENDATION_OK: + logger.info( + "[POLICY] Policy server %s recommended not to allow event %s in room %s: %s", + policy_server, + event.event_id, + event.room_id, + recommendation, + ) + return False + + return True # default allow diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index 720459f1e7..1c39cfed1b 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py
@@ -183,8 +183,13 @@ class RoomSummaryHandler: ) -> JsonDict: """See docstring for SpaceSummaryHandler.get_room_hierarchy.""" - # First of all, check that the room is accessible. - if not await self._is_local_room_accessible(requested_room_id, requester): + # If the room is available locally, quickly check that the user can access it. + local_room = await self._store.is_host_joined( + requested_room_id, self._server_name + ) + if local_room and not await self._is_local_room_accessible( + requested_room_id, requester + ): raise UnstableSpecAuthError( 403, "User %s not in room %s, and room previews are disabled" @@ -192,6 +197,22 @@ class RoomSummaryHandler: errcode=Codes.NOT_JOINED, ) + if not local_room: + room_hierarchy = await self._summarize_remote_room_hierarchy( + _RoomQueueEntry(requested_room_id, ()), + False, + ) + root_room_entry = room_hierarchy[0] + if not root_room_entry or not await self._is_remote_room_accessible( + requester, requested_room_id, root_room_entry.room + ): + raise UnstableSpecAuthError( + 403, + "User %s not in room %s, and room previews are disabled" + % (requester, requested_room_id), + errcode=Codes.NOT_JOINED, + ) + # If this is continuing a previous session, pull the persisted data. if from_token: try: @@ -679,23 +700,55 @@ class RoomSummaryHandler: """ # The API doesn't return the room version so assume that a # join rule of knock is valid. 
+ join_rule = room.get("join_rule") + world_readable = room.get("world_readable") + + logger.warning( + "[EMMA] Checking if room %s is accessible to %s: join_rule=%s, world_readable=%s", + room_id, requester, join_rule, world_readable + ) + if ( - room.get("join_rule") - in (JoinRules.PUBLIC, JoinRules.KNOCK, JoinRules.KNOCK_RESTRICTED) - or room.get("world_readable") is True + join_rule in (JoinRules.PUBLIC, JoinRules.KNOCK, JoinRules.KNOCK_RESTRICTED) + or world_readable is True ): return True - elif not requester: + else: + logger.warning( + "[EMMA] Room %s is not accessible to %s: join_rule=%s, world_readable=%s, join_rule result=%s, world_readable result=%s", + room_id, requester, join_rule, world_readable, + join_rule in (JoinRules.PUBLIC, JoinRules.KNOCK, JoinRules.KNOCK_RESTRICTED), + world_readable is True + ) + + if not requester: + logger.warning( + "[EMMA] No requester, so room %s is not accessible", + room_id + ) return False + # Check if the user is a member of any of the allowed rooms from the response. allowed_rooms = room.get("allowed_room_ids") + logger.warning( + "[EMMA] Checking if room %s is in allowed rooms for %s: join_rule=%s, allowed_rooms=%s", + requester, + room_id, + join_rule, + allowed_rooms + ) if allowed_rooms and isinstance(allowed_rooms, list): if await self._event_auth_handler.is_user_in_rooms( allowed_rooms, requester ): return True + logger.warning( + "[EMMA] Checking if room %s is accessble to %s via local state", + room_id, + requester + ) # Finally, check locally if we can access the room. The user might # already be in the room (if it was a child room), or there might be a # pending invite, etc. 
@@ -863,6 +916,10 @@ class RoomSummaryHandler: if not room_entry or not await self._is_remote_room_accessible( requester, room_entry.room_id, room_entry.room ): + logger.warning( + "[Emma] Room entry contents: %s", + room_entry.room if room_entry else None + ) raise NotFoundError("Room not found or is not accessible") room = dict(room_entry.room) diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py deleted file mode 100644
index 8ebd3d4ff9..0000000000 --- a/synapse/handlers/saml.py +++ /dev/null
@@ -1,524 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright 2019 The Matrix.org Foundation C.I.C. -# Copyright (C) 2023 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -# Originally licensed under the Apache License, Version 2.0: -# <http://www.apache.org/licenses/LICENSE-2.0>. -# -# [This file includes modifications made by New Vector Limited] -# -# -import logging -import re -from typing import TYPE_CHECKING, Callable, Dict, Optional, Set, Tuple - -import attr -import saml2 -import saml2.response -from saml2.client import Saml2Client - -from synapse.api.errors import SynapseError -from synapse.config import ConfigError -from synapse.handlers.sso import MappingException, UserAttributes -from synapse.http.servlet import parse_string -from synapse.http.site import SynapseRequest -from synapse.module_api import ModuleApi -from synapse.types import ( - MXID_LOCALPART_ALLOWED_CHARACTERS, - UserID, - map_username_to_mxid_localpart, -) -from synapse.util.iterutils import chunk_seq - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -@attr.s(slots=True, auto_attribs=True) -class Saml2SessionData: - """Data we track about SAML2 sessions""" - - # time the session was created, in milliseconds - creation_time: int - # The user interactive authentication session ID associated with this SAML - # session (or None if this SAML session is for an initial login). 
- ui_auth_session_id: Optional[str] = None - - -class SamlHandler: - def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastores().main - self.clock = hs.get_clock() - self.server_name = hs.hostname - self._saml_client = Saml2Client(hs.config.saml2.saml2_sp_config) - self._saml_idp_entityid = hs.config.saml2.saml2_idp_entityid - - self._saml2_session_lifetime = hs.config.saml2.saml2_session_lifetime - self._grandfathered_mxid_source_attribute = ( - hs.config.saml2.saml2_grandfathered_mxid_source_attribute - ) - self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements - - # plugin to do custom mapping from saml response to mxid - self._user_mapping_provider = hs.config.saml2.saml2_user_mapping_provider_class( - hs.config.saml2.saml2_user_mapping_provider_config, - ModuleApi(hs, hs.get_auth_handler()), - ) - - # identifier for the external_ids table - self.idp_id = "saml" - - # user-facing name of this auth provider - self.idp_name = hs.config.saml2.idp_name - - # MXC URI for icon for this auth provider - self.idp_icon = hs.config.saml2.idp_icon - - # optional brand identifier for this auth provider - self.idp_brand = hs.config.saml2.idp_brand - - # a map from saml session id to Saml2SessionData object - self._outstanding_requests_dict: Dict[str, Saml2SessionData] = {} - - self._sso_handler = hs.get_sso_handler() - self._sso_handler.register_identity_provider(self) - - async def handle_redirect_request( - self, - request: SynapseRequest, - client_redirect_url: Optional[bytes], - ui_auth_session_id: Optional[str] = None, - ) -> str: - """Handle an incoming request to /login/sso/redirect - - Args: - request: the incoming HTTP request - client_redirect_url: the URL that we should redirect the - client to after login (or None for UI Auth). - ui_auth_session_id: The session ID of the ongoing UI Auth (or - None if this is a login). - - Returns: - URL to redirect to - """ - if not client_redirect_url: - # Some SAML identity providers (e.g. 
Google) require a - # RelayState parameter on requests, so pass in a dummy redirect URL - # (which will never get used). - client_redirect_url = b"unused" - - reqid, info = self._saml_client.prepare_for_authenticate( - entityid=self._saml_idp_entityid, relay_state=client_redirect_url - ) - - # Since SAML sessions timeout it is useful to log when they were created. - logger.info("Initiating a new SAML session: %s" % (reqid,)) - - now = self.clock.time_msec() - self._outstanding_requests_dict[reqid] = Saml2SessionData( - creation_time=now, - ui_auth_session_id=ui_auth_session_id, - ) - - for key, value in info["headers"]: - if key == "Location": - return value - - # this shouldn't happen! - raise Exception("prepare_for_authenticate didn't return a Location header") - - async def handle_saml_response(self, request: SynapseRequest) -> None: - """Handle an incoming request to /_synapse/client/saml2/authn_response - - Args: - request: the incoming request from the browser. We'll - respond to it with a redirect. - - Returns: - Completes once we have handled the request. - """ - resp_bytes = parse_string(request, "SAMLResponse", required=True) - relay_state = parse_string(request, "RelayState", required=True) - - # expire outstanding sessions before parse_authn_request_response checks - # the dict. - self.expire_sessions() - - try: - saml2_auth = self._saml_client.parse_authn_request_response( - resp_bytes, - saml2.BINDING_HTTP_POST, - outstanding=self._outstanding_requests_dict, - ) - except saml2.response.UnsolicitedResponse as e: - # the pysaml2 library helpfully logs an ERROR here, but neglects to log - # the session ID. I don't really want to put the full text of the exception - # in the (user-visible) exception message, so let's log the exception here - # so we can track down the session IDs later. - logger.warning(str(e)) - self._sso_handler.render_error( - request, "unsolicited_response", "Unexpected SAML2 login." 
- ) - return - except Exception as e: - self._sso_handler.render_error( - request, - "invalid_response", - "Unable to parse SAML2 response: %s." % (e,), - ) - return - - if saml2_auth.not_signed: - self._sso_handler.render_error( - request, "unsigned_respond", "SAML2 response was not signed." - ) - return - - logger.debug("SAML2 response: %s", saml2_auth.origxml) - - await self._handle_authn_response(request, saml2_auth, relay_state) - - async def _handle_authn_response( - self, - request: SynapseRequest, - saml2_auth: saml2.response.AuthnResponse, - relay_state: str, - ) -> None: - """Handle an AuthnResponse, having parsed it from the request params - - Assumes that the signature on the response object has been checked. Maps - the user onto an MXID, registering them if necessary, and returns a response - to the browser. - - Args: - request: the incoming request from the browser. We'll respond to it with an - HTML page or a redirect - saml2_auth: the parsed AuthnResponse object - relay_state: the RelayState query param, which encodes the URI to rediret - back to - """ - - for assertion in saml2_auth.assertions: - # kibana limits the length of a log field, whereas this is all rather - # useful, so split it up. - count = 0 - for part in chunk_seq(str(assertion), 10000): - logger.info( - "SAML2 assertion: %s%s", "(%i)..." 
% (count,) if count else "", part - ) - count += 1 - - logger.info("SAML2 mapped attributes: %s", saml2_auth.ava) - - current_session = self._outstanding_requests_dict.pop( - saml2_auth.in_response_to, None - ) - - # first check if we're doing a UIA - if current_session and current_session.ui_auth_session_id: - try: - remote_user_id = self._remote_id_from_saml_response(saml2_auth, None) - except MappingException as e: - logger.exception("Failed to extract remote user id from SAML response") - self._sso_handler.render_error(request, "mapping_error", str(e)) - return - - return await self._sso_handler.complete_sso_ui_auth_request( - self.idp_id, - remote_user_id, - current_session.ui_auth_session_id, - request, - ) - - # otherwise, we're handling a login request. - - # Ensure that the attributes of the logged in user meet the required - # attributes. - if not self._sso_handler.check_required_attributes( - request, saml2_auth.ava, self._saml2_attribute_requirements - ): - return - - # Call the mapper to register/login the user - try: - await self._complete_saml_login(saml2_auth, request, relay_state) - except MappingException as e: - logger.exception("Could not map user") - self._sso_handler.render_error(request, "mapping_error", str(e)) - - async def _complete_saml_login( - self, - saml2_auth: saml2.response.AuthnResponse, - request: SynapseRequest, - client_redirect_url: str, - ) -> None: - """ - Given a SAML response, complete the login flow - - Retrieves the remote user ID, registers the user if necessary, and serves - a redirect back to the client with a login-token. - - Args: - saml2_auth: The parsed SAML2 response. - request: The request to respond to - client_redirect_url: The redirect URL passed in by the client. - - Raises: - MappingException if there was a problem mapping the response to a user. - RedirectException: some mapping providers may raise this if they need - to redirect to an interstitial page. 
- """ - remote_user_id = self._remote_id_from_saml_response( - saml2_auth, client_redirect_url - ) - - async def saml_response_to_remapped_user_attributes( - failures: int, - ) -> UserAttributes: - """ - Call the mapping provider to map a SAML response to user attributes and coerce the result into the standard form. - - This is backwards compatibility for abstraction for the SSO handler. - """ - # Call the mapping provider. - result = self._user_mapping_provider.saml_response_to_user_attributes( - saml2_auth, failures, client_redirect_url - ) - # Remap some of the results. - return UserAttributes( - localpart=result.get("mxid_localpart"), - display_name=result.get("displayname"), - emails=result.get("emails", []), - ) - - async def grandfather_existing_users() -> Optional[str]: - # backwards-compatibility hack: see if there is an existing user with a - # suitable mapping from the uid - if ( - self._grandfathered_mxid_source_attribute - and self._grandfathered_mxid_source_attribute in saml2_auth.ava - ): - attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0] - user_id = UserID( - map_username_to_mxid_localpart(attrval), self.server_name - ).to_string() - - logger.debug( - "Looking for existing account based on mapped %s %s", - self._grandfathered_mxid_source_attribute, - user_id, - ) - - users = await self.store.get_users_by_id_case_insensitive(user_id) - if users: - registered_user_id = list(users.keys())[0] - logger.info("Grandfathering mapping to %s", registered_user_id) - return registered_user_id - - return None - - await self._sso_handler.complete_sso_login_request( - self.idp_id, - remote_user_id, - request, - client_redirect_url, - saml_response_to_remapped_user_attributes, - grandfather_existing_users, - ) - - def _remote_id_from_saml_response( - self, - saml2_auth: saml2.response.AuthnResponse, - client_redirect_url: Optional[str], - ) -> str: - """Extract the unique remote id from a SAML2 AuthnResponse - - Args: - saml2_auth: The parsed 
SAML2 response. - client_redirect_url: The redirect URL passed in by the client. - Returns: - remote user id - - Raises: - MappingException if there was an error extracting the user id - """ - # It's not obvious why we need to pass in the redirect URI to the mapping - # provider, but we do :/ - remote_user_id = self._user_mapping_provider.get_remote_user_id( - saml2_auth, client_redirect_url - ) - - if not remote_user_id: - raise MappingException( - "Failed to extract remote user id from SAML response" - ) - - return remote_user_id - - def expire_sessions(self) -> None: - expire_before = self.clock.time_msec() - self._saml2_session_lifetime - to_expire = set() - for reqid, data in self._outstanding_requests_dict.items(): - if data.creation_time < expire_before: - to_expire.add(reqid) - for reqid in to_expire: - logger.debug("Expiring session id %s", reqid) - del self._outstanding_requests_dict[reqid] - - -DOT_REPLACE_PATTERN = re.compile( - "[^%s]" % (re.escape("".join(MXID_LOCALPART_ALLOWED_CHARACTERS)),) -) - - -def dot_replace_for_mxid(username: str) -> str: - """Replace any characters which are not allowed in Matrix IDs with a dot.""" - username = username.lower() - username = DOT_REPLACE_PATTERN.sub(".", username) - - # regular mxids aren't allowed to start with an underscore either - username = re.sub("^_", "", username) - return username - - -MXID_MAPPER_MAP: Dict[str, Callable[[str], str]] = { - "hexencode": map_username_to_mxid_localpart, - "dotreplace": dot_replace_for_mxid, -} - - -@attr.s(auto_attribs=True) -class SamlConfig: - mxid_source_attribute: str - mxid_mapper: Callable[[str], str] - - -class DefaultSamlMappingProvider: - __version__ = "0.0.1" - - def __init__(self, parsed_config: SamlConfig, module_api: ModuleApi): - """The default SAML user mapping provider - - Args: - parsed_config: Module configuration - module_api: module api proxy - """ - self._mxid_source_attribute = parsed_config.mxid_source_attribute - self._mxid_mapper = 
parsed_config.mxid_mapper - - self._grandfathered_mxid_source_attribute = ( - module_api._hs.config.saml2.saml2_grandfathered_mxid_source_attribute - ) - - def get_remote_user_id( - self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str - ) -> str: - """Extracts the remote user id from the SAML response""" - try: - return saml_response.ava["uid"][0] - except KeyError: - logger.warning("SAML2 response lacks a 'uid' attestation") - raise MappingException("'uid' not in SAML2 response") - - def saml_response_to_user_attributes( - self, - saml_response: saml2.response.AuthnResponse, - failures: int, - client_redirect_url: str, - ) -> dict: - """Maps some text from a SAML response to attributes of a new user - - Args: - saml_response: A SAML auth response object - - failures: How many times a call to this function with this - saml_response has resulted in a failure - - client_redirect_url: where the client wants to redirect to - - Returns: - A dict containing new user attributes. Possible keys: - * mxid_localpart (str): Required. The localpart of the user's mxid - * displayname (str): The displayname of the user - * emails (list[str]): Any emails for the user - """ - try: - mxid_source = saml_response.ava[self._mxid_source_attribute][0] - except KeyError: - logger.warning( - "SAML2 response lacks a '%s' attestation", - self._mxid_source_attribute, - ) - raise SynapseError( - 400, "%s not in SAML2 response" % (self._mxid_source_attribute,) - ) - - # Use the configured mapper for this mxid_source - localpart = self._mxid_mapper(mxid_source) - - # Append suffix integer if last call to this function failed to produce - # a usable mxid. 
- localpart += str(failures) if failures else "" - - # Retrieve the display name from the saml response - # If displayname is None, the mxid_localpart will be used instead - displayname = saml_response.ava.get("displayName", [None])[0] - - # Retrieve any emails present in the saml response - emails = saml_response.ava.get("email", []) - - return { - "mxid_localpart": localpart, - "displayname": displayname, - "emails": emails, - } - - @staticmethod - def parse_config(config: dict) -> SamlConfig: - """Parse the dict provided by the homeserver's config - Args: - config: A dictionary containing configuration options for this provider - Returns: - A custom config object for this module - """ - # Parse config options and use defaults where necessary - mxid_source_attribute = config.get("mxid_source_attribute", "uid") - mapping_type = config.get("mxid_mapping", "hexencode") - - # Retrieve the associating mapping function - try: - mxid_mapper = MXID_MAPPER_MAP[mapping_type] - except KeyError: - raise ConfigError( - "saml2_config.user_mapping_provider.config: '%s' is not a valid " - "mxid_mapping value" % (mapping_type,) - ) - - return SamlConfig(mxid_source_attribute, mxid_mapper) - - @staticmethod - def get_saml_attributes(config: SamlConfig) -> Tuple[Set[str], Set[str]]: - """Returns the required attributes of a SAML - - Args: - config: A SamlConfig object containing configuration params for this provider - - Returns: - The first set equates to the saml auth response - attributes that are required for the module to function, whereas the - second set consists of those attributes which can be used if - available, but are not necessary - """ - return {"uid", config.mxid_source_attribute}, {"displayName", "email"} diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index a7d52fa648..1a71135d5f 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py
@@ -423,9 +423,9 @@ class SearchHandler: } if search_result.room_groups and "room_id" in group_keys: - rooms_cat_res.setdefault("groups", {})[ - "room_id" - ] = search_result.room_groups + rooms_cat_res.setdefault("groups", {})["room_id"] = ( + search_result.room_groups + ) if sender_group and "sender" in group_keys: rooms_cat_res.setdefault("groups", {})["sender"] = sender_group diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py deleted file mode 100644
index 70cdb0721c..0000000000 --- a/synapse/handlers/send_email.py +++ /dev/null
@@ -1,230 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright 2021 The Matrix.org C.I.C. Foundation -# Copyright (C) 2023 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -# Originally licensed under the Apache License, Version 2.0: -# <http://www.apache.org/licenses/LICENSE-2.0>. -# -# [This file includes modifications made by New Vector Limited] -# -# - -import email.utils -import logging -from email.mime.multipart import MIMEMultipart -from email.mime.text import MIMEText -from io import BytesIO -from typing import TYPE_CHECKING, Any, Dict, Optional - -from pkg_resources import parse_version - -import twisted -from twisted.internet.defer import Deferred -from twisted.internet.endpoints import HostnameEndpoint -from twisted.internet.interfaces import IOpenSSLContextFactory, IProtocolFactory -from twisted.internet.ssl import optionsForClientTLS -from twisted.mail.smtp import ESMTPSender, ESMTPSenderFactory -from twisted.protocols.tls import TLSMemoryBIOFactory - -from synapse.logging.context import make_deferred_yieldable -from synapse.types import ISynapseReactor - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - -_is_old_twisted = parse_version(twisted.__version__) < parse_version("21") - - -class _NoTLSESMTPSender(ESMTPSender): - """Extend ESMTPSender to disable TLS - - Unfortunately, before Twisted 21.2, ESMTPSender doesn't give an easy way to disable - TLS, so we override its internal method which it uses to generate a context factory. 
- """ - - def _getContextFactory(self) -> Optional[IOpenSSLContextFactory]: - return None - - -async def _sendmail( - reactor: ISynapseReactor, - smtphost: str, - smtpport: int, - from_addr: str, - to_addr: str, - msg_bytes: bytes, - username: Optional[bytes] = None, - password: Optional[bytes] = None, - require_auth: bool = False, - require_tls: bool = False, - enable_tls: bool = True, - force_tls: bool = False, -) -> None: - """A simple wrapper around ESMTPSenderFactory, to allow substitution in tests - - Params: - reactor: reactor to use to make the outbound connection - smtphost: hostname to connect to - smtpport: port to connect to - from_addr: "From" address for email - to_addr: "To" address for email - msg_bytes: Message content - username: username to authenticate with, if auth is enabled - password: password to give when authenticating - require_auth: if auth is not offered, fail the request - require_tls: if TLS is not offered, fail the reqest - enable_tls: True to enable STARTTLS. If this is False and require_tls is True, - the request will fail. - force_tls: True to enable Implicit TLS. - """ - msg = BytesIO(msg_bytes) - d: "Deferred[object]" = Deferred() - - def build_sender_factory(**kwargs: Any) -> ESMTPSenderFactory: - return ESMTPSenderFactory( - username, - password, - from_addr, - to_addr, - msg, - d, - heloFallback=True, - requireAuthentication=require_auth, - requireTransportSecurity=require_tls, - **kwargs, - ) - - factory: IProtocolFactory - if _is_old_twisted: - # before twisted 21.2, we have to override the ESMTPSender protocol to disable - # TLS - factory = build_sender_factory() - - if not enable_tls: - factory.protocol = _NoTLSESMTPSender - else: - # for twisted 21.2 and later, there is a 'hostname' parameter which we should - # set to enable TLS. 
- factory = build_sender_factory(hostname=smtphost if enable_tls else None) - - if force_tls: - factory = TLSMemoryBIOFactory(optionsForClientTLS(smtphost), True, factory) - - endpoint = HostnameEndpoint( - reactor, smtphost, smtpport, timeout=30, bindAddress=None - ) - - await make_deferred_yieldable(endpoint.connect(factory)) - - await make_deferred_yieldable(d) - - -class SendEmailHandler: - def __init__(self, hs: "HomeServer"): - self.hs = hs - - self._reactor = hs.get_reactor() - - self._from = hs.config.email.email_notif_from - self._smtp_host = hs.config.email.email_smtp_host - self._smtp_port = hs.config.email.email_smtp_port - - user = hs.config.email.email_smtp_user - self._smtp_user = user.encode("utf-8") if user is not None else None - passwd = hs.config.email.email_smtp_pass - self._smtp_pass = passwd.encode("utf-8") if passwd is not None else None - self._require_transport_security = hs.config.email.require_transport_security - self._enable_tls = hs.config.email.enable_smtp_tls - self._force_tls = hs.config.email.force_tls - - self._sendmail = _sendmail - - async def send_email( - self, - email_address: str, - subject: str, - app_name: str, - html: str, - text: str, - additional_headers: Optional[Dict[str, str]] = None, - ) -> None: - """Send a multipart email with the given information. - - Args: - email_address: The address to send the email to. - subject: The email's subject. - app_name: The app name to include in the From header. - html: The HTML content to include in the email. - text: The plain text content to include in the email. - additional_headers: A map of additional headers to include. 
- """ - try: - from_string = self._from % {"app": app_name} - except (KeyError, TypeError): - from_string = self._from - - raw_from = email.utils.parseaddr(from_string)[1] - raw_to = email.utils.parseaddr(email_address)[1] - - if raw_to == "": - raise RuntimeError("Invalid 'to' address") - - html_part = MIMEText(html, "html", "utf-8") - text_part = MIMEText(text, "plain", "utf-8") - - multipart_msg = MIMEMultipart("alternative") - multipart_msg["Subject"] = subject - multipart_msg["From"] = from_string - multipart_msg["To"] = email_address - multipart_msg["Date"] = email.utils.formatdate() - multipart_msg["Message-ID"] = email.utils.make_msgid() - - # Discourage automatic responses to Synapse's emails. - # Per RFC 3834, automatic responses should not be sent if the "Auto-Submitted" - # header is present with any value other than "no". See - # https://www.rfc-editor.org/rfc/rfc3834.html#section-5.1 - multipart_msg["Auto-Submitted"] = "auto-generated" - # Also include a Microsoft-Exchange specific header: - # https://learn.microsoft.com/en-us/openspecs/exchange_server_protocols/ms-oxcmail/ced68690-498a-4567-9d14-5c01f974d8b1 - # which suggests it can take the value "All" to "suppress all auto-replies", - # or a comma separated list of auto-reply classes to suppress. 
- # The following stack overflow question has a little more context: - # https://stackoverflow.com/a/25324691/5252017 - # https://stackoverflow.com/a/61646381/5252017 - multipart_msg["X-Auto-Response-Suppress"] = "All" - - if additional_headers: - for header, value in additional_headers.items(): - multipart_msg[header] = value - - multipart_msg.attach(text_part) - multipart_msg.attach(html_part) - - logger.info("Sending email to %s" % email_address) - - await self._sendmail( - self._reactor, - self._smtp_host, - self._smtp_port, - raw_from, - raw_to, - multipart_msg.as_string().encode("utf8"), - username=self._smtp_user, - password=self._smtp_pass, - require_auth=self._smtp_user is not None, - require_tls=self._require_transport_security, - enable_tls=self._enable_tls, - force_tls=self._force_tls, - ) diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 29cc03d71d..94301add9e 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py
@@ -36,10 +36,17 @@ class SetPasswordHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self._auth_handler = hs.get_auth_handler() - # This can only be instantiated on the main process. - device_handler = hs.get_device_handler() - assert isinstance(device_handler, DeviceHandler) - self._device_handler = device_handler + + # We don't need the device handler if password changing is disabled. + # This allows us to instantiate the SetPasswordHandler on the workers + # that have admin APIs for MAS + if self._auth_handler.can_change_password(): + # This can only be instantiated on the main process. + device_handler = hs.get_device_handler() + assert isinstance(device_handler, DeviceHandler) + self._device_handler: Optional[DeviceHandler] = device_handler + else: + self._device_handler = None async def set_password( self, @@ -51,6 +58,9 @@ class SetPasswordHandler: if not self._auth_handler.can_change_password(): raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN) + # We should have this available only if password changing is enabled. + assert self._device_handler is not None + try: await self.store.user_set_password_hash(user_id, password_hash) except StoreError as e: diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py deleted file mode 100644
index 18a96843be..0000000000 --- a/synapse/handlers/sliding_sync.py +++ /dev/null
@@ -1,3158 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2024 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -# Originally licensed under the Apache License, Version 2.0: -# <http://www.apache.org/licenses/LICENSE-2.0>. -# -# [This file includes modifications made by New Vector Limited] -# -# -import enum -import logging -from enum import Enum -from itertools import chain -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Final, - List, - Literal, - Mapping, - Optional, - Sequence, - Set, - Tuple, - Union, -) - -import attr -from immutabledict import immutabledict -from typing_extensions import assert_never - -from synapse.api.constants import ( - AccountDataTypes, - Direction, - EventContentFields, - EventTypes, - Membership, -) -from synapse.api.errors import SlidingSyncUnknownPosition -from synapse.events import EventBase, StrippedStateEvent -from synapse.events.utils import parse_stripped_state_event, strip_event -from synapse.handlers.relations import BundledAggregations -from synapse.logging.opentracing import ( - SynapseTags, - log_kv, - set_tag, - start_active_span, - tag_args, - trace, -) -from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary -from synapse.storage.databases.main.state import ( - ROOM_UNKNOWN_SENTINEL, - Sentinel as StateSentinel, -) -from synapse.storage.databases.main.stream import ( - CurrentStateDeltaMembership, - PaginateFunction, -) -from synapse.storage.roommember import MemberSummary -from synapse.types import ( - DeviceListUpdates, - JsonDict, - JsonMapping, - MultiWriterStreamToken, - 
MutableStateMap, - PersistedEventPosition, - Requester, - RoomStreamToken, - SlidingSyncStreamToken, - StateMap, - StrCollection, - StreamKeyType, - StreamToken, - UserID, -) -from synapse.types.handlers import OperationType, SlidingSyncConfig, SlidingSyncResult -from synapse.types.state import StateFilter -from synapse.util.async_helpers import concurrently_execute -from synapse.visibility import filter_events_for_client - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -class Sentinel(enum.Enum): - # defining a sentinel in this way allows mypy to correctly handle the - # type of a dictionary lookup and subsequent type narrowing. - UNSET_SENTINEL = object() - - -# The event types that clients should consider as new activity. -DEFAULT_BUMP_EVENT_TYPES = { - EventTypes.Create, - EventTypes.Message, - EventTypes.Encrypted, - EventTypes.Sticker, - EventTypes.CallInvite, - EventTypes.PollStart, - EventTypes.LiveLocationShareStart, -} - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class _RoomMembershipForUser: - """ - Attributes: - room_id: The room ID of the membership event - event_id: The event ID of the membership event - event_pos: The stream position of the membership event - membership: The membership state of the user in the room - sender: The person who sent the membership event - newly_joined: Whether the user newly joined the room during the given token - range and is still joined to the room at the end of this range. - newly_left: Whether the user newly left (or kicked) the room during the given - token range and is still "leave" at the end of this range. - is_dm: Whether this user considers this room as a direct-message (DM) room - """ - - room_id: str - # Optional because state resets can affect room membership without a corresponding event. 
- event_id: Optional[str] - # Even during a state reset which removes the user from the room, we expect this to - # be set because `current_state_delta_stream` will note the position that the reset - # happened. - event_pos: PersistedEventPosition - # Even during a state reset which removes the user from the room, we expect this to - # be set to `LEAVE` because we can make that assumption based on the situaton (see - # `get_current_state_delta_membership_changes_for_user(...)`) - membership: str - # Optional because state resets can affect room membership without a corresponding event. - sender: Optional[str] - newly_joined: bool - newly_left: bool - is_dm: bool - - def copy_and_replace(self, **kwds: Any) -> "_RoomMembershipForUser": - return attr.evolve(self, **kwds) - - -def filter_membership_for_sync( - *, user_id: str, room_membership_for_user: _RoomMembershipForUser -) -> bool: - """ - Returns True if the membership event should be included in the sync response, - otherwise False. - - Attributes: - user_id: The user ID that the membership applies to - room_membership_for_user: Membership information for the user in the room - """ - - membership = room_membership_for_user.membership - sender = room_membership_for_user.sender - newly_left = room_membership_for_user.newly_left - - # We want to allow everything except rooms the user has left unless `newly_left` - # because we want everything that's *still* relevant to the user. We include - # `newly_left` rooms because the last event that the user should see is their own - # leave event. - # - # A leave != kick. This logic includes kicks (leave events where the sender is not - # the same user). - # - # When `sender=None`, it means that a state reset happened that removed the user - # from the room without a corresponding leave event. We can just remove the rooms - # since they are no longer relevant to the user but will still appear if they are - # `newly_left`. 
- return ( - # Anything except leave events - membership != Membership.LEAVE - # Unless... - or newly_left - # Allow kicks - or (membership == Membership.LEAVE and sender not in (user_id, None)) - ) - - -# We can't freeze this class because we want to update it in place with the -# de-duplicated data. -@attr.s(slots=True, auto_attribs=True) -class RoomSyncConfig: - """ - Holds the config for what data we should fetch for a room in the sync response. - - Attributes: - timeline_limit: The maximum number of events to return in the timeline. - - required_state_map: Map from state event type to state_keys requested for the - room. The values are close to `StateKey` but actually use a syntax where you - can provide `*` wildcard and `$LAZY` for lazy-loading room members. - """ - - timeline_limit: int - required_state_map: Dict[str, Set[str]] - - @classmethod - def from_room_config( - cls, - room_params: SlidingSyncConfig.CommonRoomParameters, - ) -> "RoomSyncConfig": - """ - Create a `RoomSyncConfig` from a `SlidingSyncList`/`RoomSubscription` config. - - Args: - room_params: `SlidingSyncConfig.SlidingSyncList` or `SlidingSyncConfig.RoomSubscription` - """ - required_state_map: Dict[str, Set[str]] = {} - for ( - state_type, - state_key, - ) in room_params.required_state: - # If we already have a wildcard for this specific `state_key`, we don't need - # to add it since the wildcard already covers it. 
- if state_key in required_state_map.get(StateValues.WILDCARD, set()): - continue - - # If we already have a wildcard `state_key` for this `state_type`, we don't need - # to add anything else - if StateValues.WILDCARD in required_state_map.get(state_type, set()): - continue - - # If we're getting wildcards for the `state_type` and `state_key`, that's - # all that matters so get rid of any other entries - if state_type == StateValues.WILDCARD and state_key == StateValues.WILDCARD: - required_state_map = {StateValues.WILDCARD: {StateValues.WILDCARD}} - # We can break, since we don't need to add anything else - break - - # If we're getting a wildcard for the `state_type`, get rid of any other - # entries with the same `state_key`, since the wildcard will cover it already. - elif state_type == StateValues.WILDCARD: - # Get rid of any entries that match the `state_key` - # - # Make a copy so we don't run into an error: `dictionary changed size - # during iteration`, when we remove items - for ( - existing_state_type, - existing_state_key_set, - ) in list(required_state_map.items()): - # Make a copy so we don't run into an error: `Set changed size during - # iteration`, when we filter out and remove items - for existing_state_key in existing_state_key_set.copy(): - if existing_state_key == state_key: - existing_state_key_set.remove(state_key) - - # If we've the left the `set()` empty, remove it from the map - if existing_state_key_set == set(): - required_state_map.pop(existing_state_type, None) - - # If we're getting a wildcard `state_key`, get rid of any other state_keys - # for this `state_type` since the wildcard will cover it already. 
- if state_key == StateValues.WILDCARD: - required_state_map[state_type] = {state_key} - # Otherwise, just add it to the set - else: - if required_state_map.get(state_type) is None: - required_state_map[state_type] = {state_key} - else: - required_state_map[state_type].add(state_key) - - return cls( - timeline_limit=room_params.timeline_limit, - required_state_map=required_state_map, - ) - - def deep_copy(self) -> "RoomSyncConfig": - required_state_map: Dict[str, Set[str]] = { - state_type: state_key_set.copy() - for state_type, state_key_set in self.required_state_map.items() - } - - return RoomSyncConfig( - timeline_limit=self.timeline_limit, - required_state_map=required_state_map, - ) - - def combine_room_sync_config( - self, other_room_sync_config: "RoomSyncConfig" - ) -> None: - """ - Combine this `RoomSyncConfig` with another `RoomSyncConfig` and take the - superset union of the two. - """ - # Take the highest timeline limit - if self.timeline_limit < other_room_sync_config.timeline_limit: - self.timeline_limit = other_room_sync_config.timeline_limit - - # Union the required state - for ( - state_type, - state_key_set, - ) in other_room_sync_config.required_state_map.items(): - # If we already have a wildcard for everything, we don't need to add - # anything else - if StateValues.WILDCARD in self.required_state_map.get( - StateValues.WILDCARD, set() - ): - break - - # If we already have a wildcard `state_key` for this `state_type`, we don't need - # to add anything else - if StateValues.WILDCARD in self.required_state_map.get(state_type, set()): - continue - - # If we're getting wildcards for the `state_type` and `state_key`, that's - # all that matters so get rid of any other entries - if ( - state_type == StateValues.WILDCARD - and StateValues.WILDCARD in state_key_set - ): - self.required_state_map = {state_type: {StateValues.WILDCARD}} - # We can break, since we don't need to add anything else - break - - for state_key in state_key_set: - # If we already 
have a wildcard for this specific `state_key`, we don't need - # to add it since the wildcard already covers it. - if state_key in self.required_state_map.get( - StateValues.WILDCARD, set() - ): - continue - - # If we're getting a wildcard for the `state_type`, get rid of any other - # entries with the same `state_key`, since the wildcard will cover it already. - if state_type == StateValues.WILDCARD: - # Get rid of any entries that match the `state_key` - # - # Make a copy so we don't run into an error: `dictionary changed size - # during iteration`, when we remove items - for existing_state_type, existing_state_key_set in list( - self.required_state_map.items() - ): - # Make a copy so we don't run into an error: `Set changed size during - # iteration`, when we filter out and remove items - for existing_state_key in existing_state_key_set.copy(): - if existing_state_key == state_key: - existing_state_key_set.remove(state_key) - - # If we've the left the `set()` empty, remove it from the map - if existing_state_key_set == set(): - self.required_state_map.pop(existing_state_type, None) - - # If we're getting a wildcard `state_key`, get rid of any other state_keys - # for this `state_type` since the wildcard will cover it already. - if state_key == StateValues.WILDCARD: - self.required_state_map[state_type] = {state_key} - break - # Otherwise, just add it to the set - else: - if self.required_state_map.get(state_type) is None: - self.required_state_map[state_type] = {state_key} - else: - self.required_state_map[state_type].add(state_key) - - -class StateValues: - """ - Understood values of the (type, state_key) tuple in `required_state`. - """ - - # Include all state events of the given type - WILDCARD: Final = "*" - # Lazy-load room membership events (include room membership events for any event - # `sender` in the timeline). We only give special meaning to this value when it's a - # `state_key`. - LAZY: Final = "$LAZY" - # Subsitute with the requester's user ID. 
Typically used by clients to get - # the user's membership. - ME: Final = "$ME" - - -class SlidingSyncHandler: - def __init__(self, hs: "HomeServer"): - self.clock = hs.get_clock() - self.store = hs.get_datastores().main - self.storage_controllers = hs.get_storage_controllers() - self.auth_blocking = hs.get_auth_blocking() - self.notifier = hs.get_notifier() - self.event_sources = hs.get_event_sources() - self.relations_handler = hs.get_relations_handler() - self.device_handler = hs.get_device_handler() - self.push_rules_handler = hs.get_push_rules_handler() - self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync - - self.connection_store = SlidingSyncConnectionStore() - - async def wait_for_sync_for_user( - self, - requester: Requester, - sync_config: SlidingSyncConfig, - from_token: Optional[SlidingSyncStreamToken] = None, - timeout_ms: int = 0, - ) -> SlidingSyncResult: - """ - Get the sync for a client if we have new data for it now. Otherwise - wait for new data to arrive on the server. If the timeout expires, then - return an empty sync result. - - Args: - requester: The user making the request - sync_config: Sync configuration - from_token: The point in the stream to sync from. Token of the end of the - previous batch. May be `None` if this is the initial sync request. - timeout_ms: The time in milliseconds to wait for new data to arrive. If 0, - we will immediately but there might not be any new data so we just return an - empty response. 
- """ - # If the user is not part of the mau group, then check that limits have - # not been exceeded (if not part of the group by this point, almost certain - # auth_blocking will occur) - await self.auth_blocking.check_auth_blocking(requester=requester) - - # If we're working with a user-provided token, we need to make sure to wait for - # this worker to catch up with the token so we don't skip past any incoming - # events or future events if the user is nefariously, manually modifying the - # token. - if from_token is not None: - # We need to make sure this worker has caught up with the token. If - # this returns false, it means we timed out waiting, and we should - # just return an empty response. - before_wait_ts = self.clock.time_msec() - if not await self.notifier.wait_for_stream_token(from_token.stream_token): - logger.warning( - "Timed out waiting for worker to catch up. Returning empty response" - ) - return SlidingSyncResult.empty(from_token) - - # If we've spent significant time waiting to catch up, take it off - # the timeout. - after_wait_ts = self.clock.time_msec() - if after_wait_ts - before_wait_ts > 1_000: - timeout_ms -= after_wait_ts - before_wait_ts - timeout_ms = max(timeout_ms, 0) - - # We're going to respond immediately if the timeout is 0 or if this is an - # initial sync (without a `from_token`) so we can avoid calling - # `notifier.wait_for_events()`. - if timeout_ms == 0 or from_token is None: - now_token = self.event_sources.get_current_token() - result = await self.current_sync_for_user( - sync_config, - from_token=from_token, - to_token=now_token, - ) - else: - # Otherwise, we wait for something to happen and report it to the user. 
- async def current_sync_callback( - before_token: StreamToken, after_token: StreamToken - ) -> SlidingSyncResult: - return await self.current_sync_for_user( - sync_config, - from_token=from_token, - to_token=after_token, - ) - - result = await self.notifier.wait_for_events( - sync_config.user.to_string(), - timeout_ms, - current_sync_callback, - from_token=from_token.stream_token, - ) - - return result - - @trace - async def current_sync_for_user( - self, - sync_config: SlidingSyncConfig, - to_token: StreamToken, - from_token: Optional[SlidingSyncStreamToken] = None, - ) -> SlidingSyncResult: - """ - Generates the response body of a Sliding Sync result, represented as a - `SlidingSyncResult`. - - We fetch data according to the token range (> `from_token` and <= `to_token`). - - Args: - sync_config: Sync configuration - to_token: The point in the stream to sync up to. - from_token: The point in the stream to sync from. Token of the end of the - previous batch. May be `None` if this is the initial sync request. - """ - user_id = sync_config.user.to_string() - app_service = self.store.get_app_service_by_user_id(user_id) - if app_service: - # We no longer support AS users using /sync directly. - # See https://github.com/matrix-org/matrix-doc/issues/1144 - raise NotImplementedError() - - if from_token: - # Check that we recognize the connection position, if not tell the - # clients that they need to start again. - # - # If we don't do this and the client asks for the full range of - # rooms, we end up sending down all rooms and their state from - # scratch (which can be very slow). By expiring the connection we - # allow the client a chance to do an initial request with a smaller - # range of rooms to get them some results sooner but will end up - # taking the same amount of time (more with round-trips and - # re-processing) in the end to get everything again. 
- if not await self.connection_store.is_valid_token( - sync_config, from_token.connection_position - ): - raise SlidingSyncUnknownPosition() - - await self.connection_store.mark_token_seen( - sync_config=sync_config, - from_token=from_token, - ) - - # Get all of the room IDs that the user should be able to see in the sync - # response - has_lists = sync_config.lists is not None and len(sync_config.lists) > 0 - has_room_subscriptions = ( - sync_config.room_subscriptions is not None - and len(sync_config.room_subscriptions) > 0 - ) - if has_lists or has_room_subscriptions: - room_membership_for_user_map = ( - await self.get_room_membership_for_user_at_to_token( - user=sync_config.user, - to_token=to_token, - from_token=from_token.stream_token if from_token else None, - ) - ) - - # Assemble sliding window lists - lists: Dict[str, SlidingSyncResult.SlidingWindowList] = {} - # Keep track of the rooms that we can display and need to fetch more info about - relevant_room_map: Dict[str, RoomSyncConfig] = {} - # The set of room IDs of all rooms that could appear in any list. These - # include rooms that are outside the list ranges. - all_rooms: Set[str] = set() - if has_lists and sync_config.lists is not None: - with start_active_span("assemble_sliding_window_lists"): - sync_room_map = await self.filter_rooms_relevant_for_sync( - user=sync_config.user, - room_membership_for_user_map=room_membership_for_user_map, - ) - - for list_key, list_config in sync_config.lists.items(): - # Apply filters - filtered_sync_room_map = sync_room_map - if list_config.filters is not None: - filtered_sync_room_map = await self.filter_rooms( - sync_config.user, - sync_room_map, - list_config.filters, - to_token, - ) - - # Find which rooms are partially stated and may need to be filtered out - # depending on the `required_state` requested (see below). 
- partial_state_room_map = ( - await self.store.is_partial_state_room_batched( - filtered_sync_room_map.keys() - ) - ) - - # Since creating the `RoomSyncConfig` takes some work, let's just do it - # once and make a copy whenever we need it. - room_sync_config = RoomSyncConfig.from_room_config(list_config) - membership_state_keys = room_sync_config.required_state_map.get( - EventTypes.Member - ) - # Also see `StateFilter.must_await_full_state(...)` for comparison - lazy_loading = ( - membership_state_keys is not None - and StateValues.LAZY in membership_state_keys - ) - - if not lazy_loading: - # Exclude partially-stated rooms unless the `required_state` - # only has `["m.room.member", "$LAZY"]` for membership - # (lazy-loading room members). - filtered_sync_room_map = { - room_id: room - for room_id, room in filtered_sync_room_map.items() - if not partial_state_room_map.get(room_id) - } - - all_rooms.update(filtered_sync_room_map) - - # Sort the list - sorted_room_info = await self.sort_rooms( - filtered_sync_room_map, to_token - ) - - ops: List[SlidingSyncResult.SlidingWindowList.Operation] = [] - if list_config.ranges: - for range in list_config.ranges: - room_ids_in_list: List[str] = [] - - # We're going to loop through the sorted list of rooms starting - # at the range start index and keep adding rooms until we fill - # up the range or run out of rooms. - # - # Both sides of range are inclusive so we `+ 1` - max_num_rooms = range[1] - range[0] + 1 - for room_membership in sorted_room_info[range[0] :]: - room_id = room_membership.room_id - - if len(room_ids_in_list) >= max_num_rooms: - break - - # Take the superset of the `RoomSyncConfig` for each room. - # - # Update our `relevant_room_map` with the room we're going - # to display and need to fetch more info about. 
- existing_room_sync_config = relevant_room_map.get( - room_id - ) - if existing_room_sync_config is not None: - existing_room_sync_config.combine_room_sync_config( - room_sync_config - ) - else: - # Make a copy so if we modify it later, it doesn't - # affect all references. - relevant_room_map[room_id] = ( - room_sync_config.deep_copy() - ) - - room_ids_in_list.append(room_id) - - ops.append( - SlidingSyncResult.SlidingWindowList.Operation( - op=OperationType.SYNC, - range=range, - room_ids=room_ids_in_list, - ) - ) - - lists[list_key] = SlidingSyncResult.SlidingWindowList( - count=len(sorted_room_info), - ops=ops, - ) - - # Handle room subscriptions - if has_room_subscriptions and sync_config.room_subscriptions is not None: - with start_active_span("assemble_room_subscriptions"): - for ( - room_id, - room_subscription, - ) in sync_config.room_subscriptions.items(): - room_membership_for_user_at_to_token = ( - await self.check_room_subscription_allowed_for_user( - room_id=room_id, - room_membership_for_user_map=room_membership_for_user_map, - to_token=to_token, - ) - ) - - # Skip this room if the user isn't allowed to see it - if not room_membership_for_user_at_to_token: - continue - - all_rooms.add(room_id) - - room_membership_for_user_map[room_id] = ( - room_membership_for_user_at_to_token - ) - - # Take the superset of the `RoomSyncConfig` for each room. - # - # Update our `relevant_room_map` with the room we're going to display - # and need to fetch more info about. - room_sync_config = RoomSyncConfig.from_room_config( - room_subscription - ) - existing_room_sync_config = relevant_room_map.get(room_id) - if existing_room_sync_config is not None: - existing_room_sync_config.combine_room_sync_config( - room_sync_config - ) - else: - relevant_room_map[room_id] = room_sync_config - - # Fetch room data - rooms: Dict[str, SlidingSyncResult.RoomResult] = {} - - # Filter out rooms that haven't received updates and we've sent down - # previously. 
- # Keep track of the rooms that we're going to display and need to fetch more info about - relevant_rooms_to_send_map = relevant_room_map - with start_active_span("filter_relevant_rooms_to_send"): - if from_token: - rooms_should_send = set() - - # First we check if there are rooms that match a list/room - # subscription and have updates we need to send (i.e. either because - # we haven't sent the room down, or we have but there are missing - # updates). - for room_id in relevant_room_map: - status = await self.connection_store.have_sent_room( - sync_config, - from_token.connection_position, - room_id, - ) - if ( - # The room was never sent down before so the client needs to know - # about it regardless of any updates. - status.status == HaveSentRoomFlag.NEVER - # `PREVIOUSLY` literally means the "room was sent down before *AND* - # there are updates we haven't sent down" so we already know this - # room has updates. - or status.status == HaveSentRoomFlag.PREVIOUSLY - ): - rooms_should_send.add(room_id) - elif status.status == HaveSentRoomFlag.LIVE: - # We know that we've sent all updates up until `from_token`, - # so we just need to check if there have been updates since - # then. - pass - else: - assert_never(status.status) - - # We only need to check for new events since any state changes - # will also come down as new events. 
- rooms_that_have_updates = self.store.get_rooms_that_might_have_updates( - relevant_room_map.keys(), from_token.stream_token.room_key - ) - rooms_should_send.update(rooms_that_have_updates) - relevant_rooms_to_send_map = { - room_id: room_sync_config - for room_id, room_sync_config in relevant_room_map.items() - if room_id in rooms_should_send - } - - @trace - @tag_args - async def handle_room(room_id: str) -> None: - room_sync_result = await self.get_room_sync_data( - sync_config=sync_config, - room_id=room_id, - room_sync_config=relevant_rooms_to_send_map[room_id], - room_membership_for_user_at_to_token=room_membership_for_user_map[ - room_id - ], - from_token=from_token, - to_token=to_token, - ) - - # Filter out empty room results during incremental sync - if room_sync_result or not from_token: - rooms[room_id] = room_sync_result - - if relevant_rooms_to_send_map: - with start_active_span("sliding_sync.generate_room_entries"): - await concurrently_execute(handle_room, relevant_rooms_to_send_map, 10) - - extensions = await self.get_extensions_response( - sync_config=sync_config, - actual_lists=lists, - # We're purposely using `relevant_room_map` instead of - # `relevant_rooms_to_send_map` here. This needs to be all room_ids we could - # send regardless of whether they have an event update or not. The - # extensions care about more than just normal events in the rooms (like - # account data, read receipts, typing indicators, to-device messages, etc). - actual_room_ids=set(relevant_room_map.keys()), - actual_room_response_map=rooms, - from_token=from_token, - to_token=to_token, - ) - - if has_lists or has_room_subscriptions: - # We now calculate if any rooms outside the range have had updates, - # which we are not sending down. - # - # We *must* record rooms that have had updates, but it is also fine - # to record rooms as having updates even if there might not actually - # be anything new for the user (e.g. 
due to event filters, events - # having happened after the user left, etc). - unsent_room_ids = [] - if from_token: - # The set of rooms that the client (may) care about, but aren't - # in any list range (or subscribed to). - missing_rooms = all_rooms - relevant_room_map.keys() - - # We now just go and try fetching any events in the above rooms - # to see if anything has happened since the `from_token`. - # - # TODO: Replace this with something faster. When we land the - # sliding sync tables that record the most recent event - # positions we can use that. - missing_event_map_by_room = ( - await self.store.get_room_events_stream_for_rooms( - room_ids=missing_rooms, - from_key=to_token.room_key, - to_key=from_token.stream_token.room_key, - limit=1, - ) - ) - unsent_room_ids = list(missing_event_map_by_room) - - connection_position = await self.connection_store.record_rooms( - sync_config=sync_config, - from_token=from_token, - sent_room_ids=relevant_rooms_to_send_map.keys(), - unsent_room_ids=unsent_room_ids, - ) - elif from_token: - connection_position = from_token.connection_position - else: - # Initial sync without a `from_token` starts at `0` - connection_position = 0 - - sliding_sync_result = SlidingSyncResult( - next_pos=SlidingSyncStreamToken(to_token, connection_position), - lists=lists, - rooms=rooms, - extensions=extensions, - ) - - # Make it easy to find traces for syncs that aren't empty - set_tag(SynapseTags.RESULT_PREFIX + "result", bool(sliding_sync_result)) - set_tag(SynapseTags.FUNC_ARG_PREFIX + "sync_config.user", user_id) - - return sliding_sync_result - - @trace - async def get_room_membership_for_user_at_to_token( - self, - user: UserID, - to_token: StreamToken, - from_token: Optional[StreamToken], - ) -> Dict[str, _RoomMembershipForUser]: - """ - Fetch room IDs that the user has had membership in (the full room list including - long-lost left rooms that will be filtered, sorted, and sliced). 
- - We're looking for rooms where the user has had any sort of membership in the - token range (> `from_token` and <= `to_token`) - - In order for bans/kicks to not show up, you need to `/forget` those rooms. This - doesn't modify the event itself though and only adds the `forgotten` flag to the - `room_memberships` table in Synapse. There isn't a way to tell when a room was - forgotten at the moment so we can't factor it into the token range. - - Args: - user: User to fetch rooms for - to_token: The token to fetch rooms up to. - from_token: The point in the stream to sync from. - - Returns: - A dictionary of room IDs that the user has had membership in along with - membership information in that room at the time of `to_token`. - """ - user_id = user.to_string() - - # First grab a current snapshot rooms for the user - # (also handles forgotten rooms) - room_for_user_list = await self.store.get_rooms_for_local_user_where_membership_is( - user_id=user_id, - # We want to fetch any kind of membership (joined and left rooms) in order - # to get the `event_pos` of the latest room membership event for the - # user. - membership_list=Membership.LIST, - excluded_rooms=self.rooms_to_exclude_globally, - ) - - # If the user has never joined any rooms before, we can just return an empty list - if not room_for_user_list: - return {} - - # Our working list of rooms that can show up in the sync response - sync_room_id_set = { - # Note: The `room_for_user` we're assigning here will need to be fixed up - # (below) because they are potentially from the current snapshot time - # instead from the time of the `to_token`. 
- room_for_user.room_id: _RoomMembershipForUser( - room_id=room_for_user.room_id, - event_id=room_for_user.event_id, - event_pos=room_for_user.event_pos, - membership=room_for_user.membership, - sender=room_for_user.sender, - # We will update these fields below to be accurate - newly_joined=False, - newly_left=False, - is_dm=False, - ) - for room_for_user in room_for_user_list - } - - # Get the `RoomStreamToken` that represents the spot we queried up to when we got - # our membership snapshot from `get_rooms_for_local_user_where_membership_is()`. - # - # First, we need to get the max stream_ordering of each event persister instance - # that we queried events from. - instance_to_max_stream_ordering_map: Dict[str, int] = {} - for room_for_user in room_for_user_list: - instance_name = room_for_user.event_pos.instance_name - stream_ordering = room_for_user.event_pos.stream - - current_instance_max_stream_ordering = ( - instance_to_max_stream_ordering_map.get(instance_name) - ) - if ( - current_instance_max_stream_ordering is None - or stream_ordering > current_instance_max_stream_ordering - ): - instance_to_max_stream_ordering_map[instance_name] = stream_ordering - - # Then assemble the `RoomStreamToken` - min_stream_pos = min(instance_to_max_stream_ordering_map.values()) - membership_snapshot_token = RoomStreamToken( - # Minimum position in the `instance_map` - stream=min_stream_pos, - instance_map=immutabledict( - { - instance_name: stream_pos - for instance_name, stream_pos in instance_to_max_stream_ordering_map.items() - if stream_pos > min_stream_pos - } - ), - ) - - # Since we fetched the users room list at some point in time after the from/to - # tokens, we need to revert/rewind some membership changes to match the point in - # time of the `to_token`. 
In particular, we need to make these fixups: - # - # - 1a) Remove rooms that the user joined after the `to_token` - # - 1b) Add back rooms that the user left after the `to_token` - # - 1c) Update room membership events to the point in time of the `to_token` - # - 2) Figure out which rooms are `newly_left` rooms (> `from_token` and <= `to_token`) - # - 3) Figure out which rooms are `newly_joined` (> `from_token` and <= `to_token`) - # - 4) Figure out which rooms are DM's - - # 1) Fetch membership changes that fall in the range from `to_token` up to - # `membership_snapshot_token` - # - # If our `to_token` is already the same or ahead of the latest room membership - # for the user, we don't need to do any "2)" fix-ups and can just straight-up - # use the room list from the snapshot as a base (nothing has changed) - current_state_delta_membership_changes_after_to_token = [] - if not membership_snapshot_token.is_before_or_eq(to_token.room_key): - current_state_delta_membership_changes_after_to_token = ( - await self.store.get_current_state_delta_membership_changes_for_user( - user_id, - from_key=to_token.room_key, - to_key=membership_snapshot_token, - excluded_room_ids=self.rooms_to_exclude_globally, - ) - ) - - # 1) Assemble a list of the first membership event after the `to_token` so we can - # step backward to the previous membership that would apply to the from/to - # range. - first_membership_change_by_room_id_after_to_token: Dict[ - str, CurrentStateDeltaMembership - ] = {} - for membership_change in current_state_delta_membership_changes_after_to_token: - # Only set if we haven't already set it - first_membership_change_by_room_id_after_to_token.setdefault( - membership_change.room_id, membership_change - ) - - # 1) Fixup - # - # Since we fetched a snapshot of the users room list at some point in time after - # the from/to tokens, we need to revert/rewind some membership changes to match - # the point in time of the `to_token`. 
- for ( - room_id, - first_membership_change_after_to_token, - ) in first_membership_change_by_room_id_after_to_token.items(): - # 1a) Remove rooms that the user joined after the `to_token` - if first_membership_change_after_to_token.prev_event_id is None: - sync_room_id_set.pop(room_id, None) - # 1b) 1c) From the first membership event after the `to_token`, step backward to the - # previous membership that would apply to the from/to range. - else: - # We don't expect these fields to be `None` if we have a `prev_event_id` - # but we're being defensive since it's possible that the prev event was - # culled from the database. - if ( - first_membership_change_after_to_token.prev_event_pos is not None - and first_membership_change_after_to_token.prev_membership - is not None - ): - sync_room_id_set[room_id] = _RoomMembershipForUser( - room_id=room_id, - event_id=first_membership_change_after_to_token.prev_event_id, - event_pos=first_membership_change_after_to_token.prev_event_pos, - membership=first_membership_change_after_to_token.prev_membership, - sender=first_membership_change_after_to_token.prev_sender, - # We will update these fields below to be accurate - newly_joined=False, - newly_left=False, - is_dm=False, - ) - else: - # If we can't find the previous membership event, we shouldn't - # include the room in the sync response since we can't determine the - # exact membership state and shouldn't rely on the current snapshot. 
- sync_room_id_set.pop(room_id, None) - - # 2) Fetch membership changes that fall in the range from `from_token` up to `to_token` - current_state_delta_membership_changes_in_from_to_range = [] - if from_token: - current_state_delta_membership_changes_in_from_to_range = ( - await self.store.get_current_state_delta_membership_changes_for_user( - user_id, - from_key=from_token.room_key, - to_key=to_token.room_key, - excluded_room_ids=self.rooms_to_exclude_globally, - ) - ) - - # 2) Assemble a list of the last membership events in some given ranges. Someone - # could have left and joined multiple times during the given range but we only - # care about end-result so we grab the last one. - last_membership_change_by_room_id_in_from_to_range: Dict[ - str, CurrentStateDeltaMembership - ] = {} - # We also want to assemble a list of the first membership events during the token - # range so we can step backward to the previous membership that would apply to - # before the token range to see if we have `newly_joined` the room. - first_membership_change_by_room_id_in_from_to_range: Dict[ - str, CurrentStateDeltaMembership - ] = {} - # Keep track if the room has a non-join event in the token range so we can later - # tell if it was a `newly_joined` room. If the last membership event in the - # token range is a join and there is also some non-join in the range, we know - # they `newly_joined`. 
- has_non_join_event_by_room_id_in_from_to_range: Dict[str, bool] = {} - for ( - membership_change - ) in current_state_delta_membership_changes_in_from_to_range: - room_id = membership_change.room_id - - last_membership_change_by_room_id_in_from_to_range[room_id] = ( - membership_change - ) - # Only set if we haven't already set it - first_membership_change_by_room_id_in_from_to_range.setdefault( - room_id, membership_change - ) - - if membership_change.membership != Membership.JOIN: - has_non_join_event_by_room_id_in_from_to_range[room_id] = True - - # 2) Fixup - # - # 3) We also want to assemble a list of possibly newly joined rooms. Someone - # could have left and joined multiple times during the given range but we only - # care about whether they are joined at the end of the token range so we are - # working with the last membership even in the token range. - possibly_newly_joined_room_ids = set() - for ( - last_membership_change_in_from_to_range - ) in last_membership_change_by_room_id_in_from_to_range.values(): - room_id = last_membership_change_in_from_to_range.room_id - - # 3) - if last_membership_change_in_from_to_range.membership == Membership.JOIN: - possibly_newly_joined_room_ids.add(room_id) - - # 2) Figure out newly_left rooms (> `from_token` and <= `to_token`). - if last_membership_change_in_from_to_range.membership == Membership.LEAVE: - # 2) Mark this room as `newly_left` - - # If we're seeing a membership change here, we should expect to already - # have it in our snapshot but if a state reset happens, it wouldn't have - # shown up in our snapshot but appear as a change here. - existing_sync_entry = sync_room_id_set.get(room_id) - if existing_sync_entry is not None: - # Normal expected case - sync_room_id_set[room_id] = existing_sync_entry.copy_and_replace( - newly_left=True - ) - else: - # State reset! 
- logger.warn( - "State reset detected for room_id %s with %s who is no longer in the room", - room_id, - user_id, - ) - # Even though a state reset happened which removed the person from - # the room, we still add it the list so the user knows they left the - # room. Downstream code can check for a state reset by looking for - # `event_id=None and membership is not None`. - sync_room_id_set[room_id] = _RoomMembershipForUser( - room_id=room_id, - event_id=last_membership_change_in_from_to_range.event_id, - event_pos=last_membership_change_in_from_to_range.event_pos, - membership=last_membership_change_in_from_to_range.membership, - sender=last_membership_change_in_from_to_range.sender, - newly_joined=False, - newly_left=True, - is_dm=False, - ) - - # 3) Figure out `newly_joined` - for room_id in possibly_newly_joined_room_ids: - has_non_join_in_from_to_range = ( - has_non_join_event_by_room_id_in_from_to_range.get(room_id, False) - ) - # If the last membership event in the token range is a join and there is - # also some non-join in the range, we know they `newly_joined`. - if has_non_join_in_from_to_range: - # We found a `newly_joined` room (we left and joined within the token range) - sync_room_id_set[room_id] = sync_room_id_set[room_id].copy_and_replace( - newly_joined=True - ) - else: - prev_event_id = first_membership_change_by_room_id_in_from_to_range[ - room_id - ].prev_event_id - prev_membership = first_membership_change_by_room_id_in_from_to_range[ - room_id - ].prev_membership - - if prev_event_id is None: - # We found a `newly_joined` room (we are joining the room for the - # first time within the token range) - sync_room_id_set[room_id] = sync_room_id_set[ - room_id - ].copy_and_replace(newly_joined=True) - # Last resort, we need to step back to the previous membership event - # just before the token range to see if we're joined then or not. 
- elif prev_membership != Membership.JOIN: - # We found a `newly_joined` room (we left before the token range - # and joined within the token range) - sync_room_id_set[room_id] = sync_room_id_set[ - room_id - ].copy_and_replace(newly_joined=True) - - # 4) Figure out which rooms the user considers to be direct-message (DM) rooms - # - # We're using global account data (`m.direct`) instead of checking for - # `is_direct` on membership events because that property only appears for - # the invitee membership event (doesn't show up for the inviter). - # - # We're unable to take `to_token` into account for global account data since - # we only keep track of the latest account data for the user. - dm_map = await self.store.get_global_account_data_by_type_for_user( - user_id, AccountDataTypes.DIRECT - ) - - # Flatten out the map. Account data is set by the client so it needs to be - # scrutinized. - dm_room_id_set = set() - if isinstance(dm_map, dict): - for room_ids in dm_map.values(): - # Account data should be a list of room IDs. Ignore anything else - if isinstance(room_ids, list): - for room_id in room_ids: - if isinstance(room_id, str): - dm_room_id_set.add(room_id) - - # 4) Fixup - for room_id in sync_room_id_set: - sync_room_id_set[room_id] = sync_room_id_set[room_id].copy_and_replace( - is_dm=room_id in dm_room_id_set - ) - - return sync_room_id_set - - @trace - async def filter_rooms_relevant_for_sync( - self, - user: UserID, - room_membership_for_user_map: Dict[str, _RoomMembershipForUser], - ) -> Dict[str, _RoomMembershipForUser]: - """ - Filter room IDs that should/can be listed for this user in the sync response (the - full room list that will be further filtered, sorted, and sliced). 
- - We're looking for rooms where the user has the following state in the token - range (> `from_token` and <= `to_token`): - - - `invite`, `join`, `knock`, `ban` membership events - - Kicks (`leave` membership events where `sender` is different from the - `user_id`/`state_key`) - - `newly_left` (rooms that were left during the given token range) - - In order for bans/kicks to not show up in sync, you need to `/forget` those - rooms. This doesn't modify the event itself though and only adds the - `forgotten` flag to the `room_memberships` table in Synapse. There isn't a way - to tell when a room was forgotten at the moment so we can't factor it into the - from/to range. - - Args: - user: User that is syncing - room_membership_for_user_map: Room membership for the user - - Returns: - A dictionary of room IDs that should be listed in the sync response along - with membership information in that room at the time of `to_token`. - """ - user_id = user.to_string() - - # Filter rooms to only what we're interested to sync with - filtered_sync_room_map = { - room_id: room_membership_for_user - for room_id, room_membership_for_user in room_membership_for_user_map.items() - if filter_membership_for_sync( - user_id=user_id, - room_membership_for_user=room_membership_for_user, - ) - } - - return filtered_sync_room_map - - async def check_room_subscription_allowed_for_user( - self, - room_id: str, - room_membership_for_user_map: Dict[str, _RoomMembershipForUser], - to_token: StreamToken, - ) -> Optional[_RoomMembershipForUser]: - """ - Check whether the user is allowed to see the room based on whether they have - ever had membership in the room or if the room is `world_readable`. - - Similar to `check_user_in_room_or_world_readable(...)` - - Args: - room_id: Room to check - room_membership_for_user_map: Room membership for the user at the time of - the `to_token` (<= `to_token`). - to_token: The token to fetch rooms up to. 
- - Returns: - The room membership for the user if they are allowed to subscribe to the - room else `None`. - """ - - # We can first check if they are already allowed to see the room based - # on our previous work to assemble the `room_membership_for_user_map`. - # - # If they have had any membership in the room over time (up to the `to_token`), - # let them subscribe and see what they can. - existing_membership_for_user = room_membership_for_user_map.get(room_id) - if existing_membership_for_user is not None: - return existing_membership_for_user - - # TODO: Handle `world_readable` rooms - return None - - # If the room is `world_readable`, it doesn't matter whether they can join, - # everyone can see the room. - # not_in_room_membership_for_user = _RoomMembershipForUser( - # room_id=room_id, - # event_id=None, - # event_pos=None, - # membership=None, - # sender=None, - # newly_joined=False, - # newly_left=False, - # is_dm=False, - # ) - # room_state = await self.get_current_state_at( - # room_id=room_id, - # room_membership_for_user_at_to_token=not_in_room_membership_for_user, - # state_filter=StateFilter.from_types( - # [(EventTypes.RoomHistoryVisibility, "")] - # ), - # to_token=to_token, - # ) - - # visibility_event = room_state.get((EventTypes.RoomHistoryVisibility, "")) - # if ( - # visibility_event is not None - # and visibility_event.content.get("history_visibility") - # == HistoryVisibility.WORLD_READABLE - # ): - # return not_in_room_membership_for_user - - # return None - - @trace - async def _bulk_get_stripped_state_for_rooms_from_sync_room_map( - self, - room_ids: StrCollection, - sync_room_map: Dict[str, _RoomMembershipForUser], - ) -> Dict[str, Optional[StateMap[StrippedStateEvent]]]: - """ - Fetch stripped state for a list of room IDs. Stripped state is only - applicable to invite/knock rooms. Other rooms will have `None` as their - stripped state. - - For invite rooms, we pull from `unsigned.invite_room_state`. 
- For knock rooms, we pull from `unsigned.knock_room_state`. - - Args: - room_ids: Room IDs to fetch stripped state for - sync_room_map: Dictionary of room IDs to sort along with membership - information in the room at the time of `to_token`. - - Returns: - Mapping from room_id to mapping of (type, state_key) to stripped state - event. - """ - room_id_to_stripped_state_map: Dict[ - str, Optional[StateMap[StrippedStateEvent]] - ] = {} - - # Fetch what we haven't before - room_ids_to_fetch = [ - room_id - for room_id in room_ids - if room_id not in room_id_to_stripped_state_map - ] - - # Gather a list of event IDs we can grab stripped state from - invite_or_knock_event_ids: List[str] = [] - for room_id in room_ids_to_fetch: - if sync_room_map[room_id].membership in ( - Membership.INVITE, - Membership.KNOCK, - ): - event_id = sync_room_map[room_id].event_id - # If this is an invite/knock then there should be an event_id - assert event_id is not None - invite_or_knock_event_ids.append(event_id) - else: - room_id_to_stripped_state_map[room_id] = None - - invite_or_knock_events = await self.store.get_events(invite_or_knock_event_ids) - for invite_or_knock_event in invite_or_knock_events.values(): - room_id = invite_or_knock_event.room_id - membership = invite_or_knock_event.membership - - raw_stripped_state_events = None - if membership == Membership.INVITE: - invite_room_state = invite_or_knock_event.unsigned.get( - "invite_room_state" - ) - raw_stripped_state_events = invite_room_state - elif membership == Membership.KNOCK: - knock_room_state = invite_or_knock_event.unsigned.get( - "knock_room_state" - ) - raw_stripped_state_events = knock_room_state - else: - raise AssertionError( - f"Unexpected membership {membership} (this is a problem with Synapse itself)" - ) - - stripped_state_map: Optional[MutableStateMap[StrippedStateEvent]] = None - # Scrutinize unsigned things. 
`raw_stripped_state_events` should be a list - # of stripped events - if raw_stripped_state_events is not None: - stripped_state_map = {} - if isinstance(raw_stripped_state_events, list): - for raw_stripped_event in raw_stripped_state_events: - stripped_state_event = parse_stripped_state_event( - raw_stripped_event - ) - if stripped_state_event is not None: - stripped_state_map[ - ( - stripped_state_event.type, - stripped_state_event.state_key, - ) - ] = stripped_state_event - - room_id_to_stripped_state_map[room_id] = stripped_state_map - - return room_id_to_stripped_state_map - - @trace - async def _bulk_get_partial_current_state_content_for_rooms( - self, - content_type: Literal[ - # `content.type` from `EventTypes.Create`` - "room_type", - # `content.algorithm` from `EventTypes.RoomEncryption` - "room_encryption", - ], - room_ids: Set[str], - sync_room_map: Dict[str, _RoomMembershipForUser], - to_token: StreamToken, - room_id_to_stripped_state_map: Dict[ - str, Optional[StateMap[StrippedStateEvent]] - ], - ) -> Mapping[str, Union[Optional[str], StateSentinel]]: - """ - Get the given state event content for a list of rooms. First we check the - current state of the room, then fallback to stripped state if available, then - historical state. - - Args: - content_type: Which content to grab - room_ids: Room IDs to fetch the given content field for. - sync_room_map: Dictionary of room IDs to sort along with membership - information in the room at the time of `to_token`. - to_token: We filter based on the state of the room at this token - room_id_to_stripped_state_map: This does not need to be filled in before - calling this function. Mapping from room_id to mapping of (type, state_key) - to stripped state event. Modified in place when we fetch new rooms so we can - save work next time this function is called. - - Returns: - A mapping from room ID to the state event content if the room has - the given state event (event_type, ""), otherwise `None`. 
Rooms unknown to - this server will return `ROOM_UNKNOWN_SENTINEL`. - """ - room_id_to_content: Dict[str, Union[Optional[str], StateSentinel]] = {} - - # As a bulk shortcut, use the current state if the server is particpating in the - # room (meaning we have current state). Ideally, for leave/ban rooms, we would - # want the state at the time of the membership instead of current state to not - # leak anything but we consider the create/encryption stripped state events to - # not be a secret given they are often set at the start of the room and they are - # normally handed out on invite/knock. - # - # Be mindful to only use this for non-sensitive details. For example, even - # though the room name/avatar/topic are also stripped state, they seem a lot - # more senstive to leak the current state value of. - # - # Since this function is cached, we need to make a mutable copy via - # `dict(...)`. - event_type = "" - event_content_field = "" - if content_type == "room_type": - event_type = EventTypes.Create - event_content_field = EventContentFields.ROOM_TYPE - room_id_to_content = dict(await self.store.bulk_get_room_type(room_ids)) - elif content_type == "room_encryption": - event_type = EventTypes.RoomEncryption - event_content_field = EventContentFields.ENCRYPTION_ALGORITHM - room_id_to_content = dict( - await self.store.bulk_get_room_encryption(room_ids) - ) - else: - assert_never(content_type) - - room_ids_with_results = [ - room_id - for room_id, content_field in room_id_to_content.items() - if content_field is not ROOM_UNKNOWN_SENTINEL - ] - - # We might not have current room state for remote invite/knocks if we are - # the first person on our server to see the room. The best we can do is look - # in the optional stripped state from the invite/knock event. 
- room_ids_without_results = room_ids.difference( - chain( - room_ids_with_results, - [ - room_id - for room_id, stripped_state_map in room_id_to_stripped_state_map.items() - if stripped_state_map is not None - ], - ) - ) - room_id_to_stripped_state_map.update( - await self._bulk_get_stripped_state_for_rooms_from_sync_room_map( - room_ids_without_results, sync_room_map - ) - ) - - # Update our `room_id_to_content` map based on the stripped state - # (applies to invite/knock rooms) - rooms_ids_without_stripped_state: Set[str] = set() - for room_id in room_ids_without_results: - stripped_state_map = room_id_to_stripped_state_map.get( - room_id, Sentinel.UNSET_SENTINEL - ) - assert stripped_state_map is not Sentinel.UNSET_SENTINEL, ( - f"Stripped state left unset for room {room_id}. " - + "Make sure you're calling `_bulk_get_stripped_state_for_rooms_from_sync_room_map(...)` " - + "with that room_id. (this is a problem with Synapse itself)" - ) - - # If there is some stripped state, we assume the remote server passed *all* - # of the potential stripped state events for the room. - if stripped_state_map is not None: - create_stripped_event = stripped_state_map.get((EventTypes.Create, "")) - stripped_event = stripped_state_map.get((event_type, "")) - # Sanity check that we at-least have the create event - if create_stripped_event is not None: - if stripped_event is not None: - room_id_to_content[room_id] = stripped_event.content.get( - event_content_field - ) - else: - # Didn't see the state event we're looking for in the stripped - # state so we can assume relevant content field is `None`. - room_id_to_content[room_id] = None - else: - rooms_ids_without_stripped_state.add(room_id) - - # Last resort, we might not have current room state for rooms that the - # server has left (no one local is in the room) but we can look at the - # historical state. - # - # Update our `room_id_to_content` map based on the state at the time of - # the membership event. 
- for room_id in rooms_ids_without_stripped_state: - # TODO: It would be nice to look this up in a bulk way (N+1 queries) - # - # TODO: `get_state_at(...)` doesn't take into account the "current state". - room_state = await self.storage_controllers.state.get_state_at( - room_id=room_id, - stream_position=to_token.copy_and_replace( - StreamKeyType.ROOM, - sync_room_map[room_id].event_pos.to_room_stream_token(), - ), - state_filter=StateFilter.from_types( - [ - (EventTypes.Create, ""), - (event_type, ""), - ] - ), - # Partially-stated rooms should have all state events except for - # remote membership events so we don't need to wait at all because - # we only want the create event and some non-member event. - await_full_state=False, - ) - # We can use the create event as a canary to tell whether the server has - # seen the room before - create_event = room_state.get((EventTypes.Create, "")) - state_event = room_state.get((event_type, "")) - - if create_event is None: - # Skip for unknown rooms - continue - - if state_event is not None: - room_id_to_content[room_id] = state_event.content.get( - event_content_field - ) - else: - # Didn't see the state event we're looking for in the stripped - # state so we can assume relevant content field is `None`. - room_id_to_content[room_id] = None - - return room_id_to_content - - @trace - async def filter_rooms( - self, - user: UserID, - sync_room_map: Dict[str, _RoomMembershipForUser], - filters: SlidingSyncConfig.SlidingSyncList.Filters, - to_token: StreamToken, - ) -> Dict[str, _RoomMembershipForUser]: - """ - Filter rooms based on the sync request. - - Args: - user: User to filter rooms for - sync_room_map: Dictionary of room IDs to sort along with membership - information in the room at the time of `to_token`. 
- filters: Filters to apply - to_token: We filter based on the state of the room at this token - - Returns: - A filtered dictionary of room IDs along with membership information in the - room at the time of `to_token`. - """ - room_id_to_stripped_state_map: Dict[ - str, Optional[StateMap[StrippedStateEvent]] - ] = {} - - filtered_room_id_set = set(sync_room_map.keys()) - - # Filter for Direct-Message (DM) rooms - if filters.is_dm is not None: - with start_active_span("filters.is_dm"): - if filters.is_dm: - # Only DM rooms please - filtered_room_id_set = { - room_id - for room_id in filtered_room_id_set - if sync_room_map[room_id].is_dm - } - else: - # Only non-DM rooms please - filtered_room_id_set = { - room_id - for room_id in filtered_room_id_set - if not sync_room_map[room_id].is_dm - } - - if filters.spaces is not None: - with start_active_span("filters.spaces"): - raise NotImplementedError() - - # Filter for encrypted rooms - if filters.is_encrypted is not None: - with start_active_span("filters.is_encrypted"): - room_id_to_encryption = ( - await self._bulk_get_partial_current_state_content_for_rooms( - content_type="room_encryption", - room_ids=filtered_room_id_set, - to_token=to_token, - sync_room_map=sync_room_map, - room_id_to_stripped_state_map=room_id_to_stripped_state_map, - ) - ) - - # Make a copy so we don't run into an error: `Set changed size during - # iteration`, when we filter out and remove items - for room_id in filtered_room_id_set.copy(): - encryption = room_id_to_encryption.get( - room_id, ROOM_UNKNOWN_SENTINEL - ) - - # Just remove rooms if we can't determine their encryption status - if encryption is ROOM_UNKNOWN_SENTINEL: - filtered_room_id_set.remove(room_id) - continue - - # If we're looking for encrypted rooms, filter out rooms that are not - # encrypted and vice versa - is_encrypted = encryption is not None - if (filters.is_encrypted and not is_encrypted) or ( - not filters.is_encrypted and is_encrypted - ): - 
filtered_room_id_set.remove(room_id) - - # Filter for rooms that the user has been invited to - if filters.is_invite is not None: - with start_active_span("filters.is_invite"): - # Make a copy so we don't run into an error: `Set changed size during - # iteration`, when we filter out and remove items - for room_id in filtered_room_id_set.copy(): - room_for_user = sync_room_map[room_id] - # If we're looking for invite rooms, filter out rooms that the user is - # not invited to and vice versa - if ( - filters.is_invite - and room_for_user.membership != Membership.INVITE - ) or ( - not filters.is_invite - and room_for_user.membership == Membership.INVITE - ): - filtered_room_id_set.remove(room_id) - - # Filter by room type (space vs room, etc). A room must match one of the types - # provided in the list. `None` is a valid type for rooms which do not have a - # room type. - if filters.room_types is not None or filters.not_room_types is not None: - with start_active_span("filters.room_types"): - room_id_to_type = ( - await self._bulk_get_partial_current_state_content_for_rooms( - content_type="room_type", - room_ids=filtered_room_id_set, - to_token=to_token, - sync_room_map=sync_room_map, - room_id_to_stripped_state_map=room_id_to_stripped_state_map, - ) - ) - - # Make a copy so we don't run into an error: `Set changed size during - # iteration`, when we filter out and remove items - for room_id in filtered_room_id_set.copy(): - room_type = room_id_to_type.get(room_id, ROOM_UNKNOWN_SENTINEL) - - # Just remove rooms if we can't determine their type - if room_type is ROOM_UNKNOWN_SENTINEL: - filtered_room_id_set.remove(room_id) - continue - - if ( - filters.room_types is not None - and room_type not in filters.room_types - ): - filtered_room_id_set.remove(room_id) - - if ( - filters.not_room_types is not None - and room_type in filters.not_room_types - ): - filtered_room_id_set.remove(room_id) - - if filters.room_name_like is not None: - with 
start_active_span("filters.room_name_like"): - # TODO: The room name is a bit more sensitive to leak than the - # create/encryption event. Maybe we should consider a better way to fetch - # historical state before implementing this. - # - # room_id_to_create_content = await self._bulk_get_partial_current_state_content_for_rooms( - # content_type="room_name", - # room_ids=filtered_room_id_set, - # to_token=to_token, - # sync_room_map=sync_room_map, - # room_id_to_stripped_state_map=room_id_to_stripped_state_map, - # ) - raise NotImplementedError() - - if filters.tags is not None or filters.not_tags is not None: - with start_active_span("filters.tags"): - raise NotImplementedError() - - # Assemble a new sync room map but only with the `filtered_room_id_set` - return {room_id: sync_room_map[room_id] for room_id in filtered_room_id_set} - - @trace - async def sort_rooms( - self, - sync_room_map: Dict[str, _RoomMembershipForUser], - to_token: StreamToken, - ) -> List[_RoomMembershipForUser]: - """ - Sort by `stream_ordering` of the last event that the user should see in the - room. `stream_ordering` is unique so we get a stable sort. - - Args: - sync_room_map: Dictionary of room IDs to sort along with membership - information in the room at the time of `to_token`. - to_token: We sort based on the events in the room at this token (<= `to_token`) - - Returns: - A sorted list of room IDs by `stream_ordering` along with membership information. - """ - - # Assemble a map of room ID to the `stream_ordering` of the last activity that the - # user should see in the room (<= `to_token`) - last_activity_in_room_map: Dict[str, int] = {} - - for room_id, room_for_user in sync_room_map.items(): - if room_for_user.membership != Membership.JOIN: - # If the user has left/been invited/knocked/been banned from a - # room, they shouldn't see anything past that point. 
- # - # FIXME: It's possible that people should see beyond this point - # in invited/knocked cases if for example the room has - # `invite`/`world_readable` history visibility, see - # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932 - last_activity_in_room_map[room_id] = room_for_user.event_pos.stream - - # For fully-joined rooms, we find the latest activity at/before the - # `to_token`. - joined_room_positions = ( - await self.store.bulk_get_last_event_pos_in_room_before_stream_ordering( - [ - room_id - for room_id, room_for_user in sync_room_map.items() - if room_for_user.membership == Membership.JOIN - ], - to_token.room_key, - ) - ) - - last_activity_in_room_map.update(joined_room_positions) - - return sorted( - sync_room_map.values(), - # Sort by the last activity (stream_ordering) in the room - key=lambda room_info: last_activity_in_room_map[room_info.room_id], - # We want descending order - reverse=True, - ) - - @trace - async def get_current_state_ids_at( - self, - room_id: str, - room_membership_for_user_at_to_token: _RoomMembershipForUser, - state_filter: StateFilter, - to_token: StreamToken, - ) -> StateMap[str]: - """ - Get current state IDs for the user in the room according to their membership. This - will be the current state at the time of their LEAVE/BAN, otherwise will be the - current state <= to_token. - - Args: - room_id: The room ID to fetch data for - room_membership_for_user_at_token: Membership information for the user - in the room at the time of `to_token`. - to_token: The point in the stream to sync up to. - """ - state_ids: StateMap[str] - # People shouldn't see past their leave/ban event - if room_membership_for_user_at_to_token.membership in ( - Membership.LEAVE, - Membership.BAN, - ): - # TODO: `get_state_ids_at(...)` doesn't take into account the "current - # state". Maybe we need to use - # `get_forward_extremities_for_room_at_stream_ordering(...)` to "Fetch the - # current state at the time." 
- state_ids = await self.storage_controllers.state.get_state_ids_at( - room_id, - stream_position=to_token.copy_and_replace( - StreamKeyType.ROOM, - room_membership_for_user_at_to_token.event_pos.to_room_stream_token(), - ), - state_filter=state_filter, - # Partially-stated rooms should have all state events except for - # remote membership events. Since we've already excluded - # partially-stated rooms unless `required_state` only has - # `["m.room.member", "$LAZY"]` for membership, we should be able to - # retrieve everything requested. When we're lazy-loading, if there - # are some remote senders in the timeline, we should also have their - # membership event because we had to auth that timeline event. Plus - # we don't want to block the whole sync waiting for this one room. - await_full_state=False, - ) - # Otherwise, we can get the latest current state in the room - else: - state_ids = await self.storage_controllers.state.get_current_state_ids( - room_id, - state_filter, - # Partially-stated rooms should have all state events except for - # remote membership events. Since we've already excluded - # partially-stated rooms unless `required_state` only has - # `["m.room.member", "$LAZY"]` for membership, we should be able to - # retrieve everything requested. When we're lazy-loading, if there - # are some remote senders in the timeline, we should also have their - # membership event because we had to auth that timeline event. Plus - # we don't want to block the whole sync waiting for this one room. - await_full_state=False, - ) - # TODO: Query `current_state_delta_stream` and reverse/rewind back to the `to_token` - - return state_ids - - @trace - async def get_current_state_at( - self, - room_id: str, - room_membership_for_user_at_to_token: _RoomMembershipForUser, - state_filter: StateFilter, - to_token: StreamToken, - ) -> StateMap[EventBase]: - """ - Get current state for the user in the room according to their membership. 
This - will be the current state at the time of their LEAVE/BAN, otherwise will be the - current state <= to_token. - - Args: - room_id: The room ID to fetch data for - room_membership_for_user_at_token: Membership information for the user - in the room at the time of `to_token`. - to_token: The point in the stream to sync up to. - """ - state_ids = await self.get_current_state_ids_at( - room_id=room_id, - room_membership_for_user_at_to_token=room_membership_for_user_at_to_token, - state_filter=state_filter, - to_token=to_token, - ) - - event_map = await self.store.get_events(list(state_ids.values())) - - state_map = {} - for key, event_id in state_ids.items(): - event = event_map.get(event_id) - if event: - state_map[key] = event - - return state_map - - async def get_room_sync_data( - self, - sync_config: SlidingSyncConfig, - room_id: str, - room_sync_config: RoomSyncConfig, - room_membership_for_user_at_to_token: _RoomMembershipForUser, - from_token: Optional[SlidingSyncStreamToken], - to_token: StreamToken, - ) -> SlidingSyncResult.RoomResult: - """ - Fetch room data for the sync response. - - We fetch data according to the token range (> `from_token` and <= `to_token`). - - Args: - user: User to fetch data for - room_id: The room ID to fetch data for - room_sync_config: Config for what data we should fetch for a room in the - sync response. - room_membership_for_user_at_to_token: Membership information for the user - in the room at the time of `to_token`. - from_token: The point in the stream to sync from. - to_token: The point in the stream to sync up to. - """ - user = sync_config.user - - set_tag( - SynapseTags.FUNC_ARG_PREFIX + "membership", - room_membership_for_user_at_to_token.membership, - ) - set_tag( - SynapseTags.FUNC_ARG_PREFIX + "timeline_limit", - room_sync_config.timeline_limit, - ) - - # Determine whether we should limit the timeline to the token range. 
- # - # We should return historical messages (before token range) in the - # following cases because we want clients to be able to show a basic - # screen of information: - # - # - Initial sync (because no `from_token` to limit us anyway) - # - When users `newly_joined` - # - For an incremental sync where we haven't sent it down this - # connection before - # - # Relevant spec issue: https://github.com/matrix-org/matrix-spec/issues/1917 - from_bound = None - initial = True - if from_token and not room_membership_for_user_at_to_token.newly_joined: - room_status = await self.connection_store.have_sent_room( - sync_config=sync_config, - connection_token=from_token.connection_position, - room_id=room_id, - ) - if room_status.status == HaveSentRoomFlag.LIVE: - from_bound = from_token.stream_token.room_key - initial = False - elif room_status.status == HaveSentRoomFlag.PREVIOUSLY: - assert room_status.last_token is not None - from_bound = room_status.last_token - initial = False - elif room_status.status == HaveSentRoomFlag.NEVER: - from_bound = None - initial = True - else: - assert_never(room_status.status) - - log_kv({"sliding_sync.room_status": room_status}) - - log_kv({"sliding_sync.from_bound": from_bound, "sliding_sync.initial": initial}) - - # Assemble the list of timeline events - # - # FIXME: It would be nice to make the `rooms` response more uniform regardless of - # membership. Currently, we have to make all of these optional because - # `invite`/`knock` rooms only have `stripped_state`. 
See - # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932 - timeline_events: List[EventBase] = [] - bundled_aggregations: Optional[Dict[str, BundledAggregations]] = None - limited: Optional[bool] = None - prev_batch_token: Optional[StreamToken] = None - num_live: Optional[int] = None - if ( - room_sync_config.timeline_limit > 0 - # No timeline for invite/knock rooms (just `stripped_state`) - and room_membership_for_user_at_to_token.membership - not in (Membership.INVITE, Membership.KNOCK) - ): - limited = False - # We want to start off using the `to_token` (vs `from_token`) because we look - # backwards from the `to_token` up to the `timeline_limit` and we might not - # reach the `from_token` before we hit the limit. We will update the room stream - # position once we've fetched the events to point to the earliest event fetched. - prev_batch_token = to_token - - # We're going to paginate backwards from the `to_token` - to_bound = to_token.room_key - # People shouldn't see past their leave/ban event - if room_membership_for_user_at_to_token.membership in ( - Membership.LEAVE, - Membership.BAN, - ): - to_bound = ( - room_membership_for_user_at_to_token.event_pos.to_room_stream_token() - ) - - # For initial `/sync` (and other historical scenarios mentioned above), we - # want to view a historical section of the timeline; to fetch events by - # `topological_ordering` (best representation of the room DAG as others were - # seeing it at the time). This also aligns with the order that `/messages` - # returns events in. - # - # For incremental `/sync`, we want to get all updates for rooms since - # the last `/sync` (regardless if those updates arrived late or happened - # a while ago in the past); to fetch events by `stream_ordering` (in the - # order they were received by the server). 
- # - # Relevant spec issue: https://github.com/matrix-org/matrix-spec/issues/1917 - # - # FIXME: Using workaround for mypy, - # https://github.com/python/mypy/issues/10740#issuecomment-1997047277 and - # https://github.com/python/mypy/issues/17479 - paginate_room_events_by_topological_ordering: PaginateFunction = ( - self.store.paginate_room_events_by_topological_ordering - ) - paginate_room_events_by_stream_ordering: PaginateFunction = ( - self.store.paginate_room_events_by_stream_ordering - ) - pagination_method: PaginateFunction = ( - # Use `topographical_ordering` for historical events - paginate_room_events_by_topological_ordering - if from_bound is None - # Use `stream_ordering` for updates - else paginate_room_events_by_stream_ordering - ) - timeline_events, new_room_key = await pagination_method( - room_id=room_id, - # The bounds are reversed so we can paginate backwards - # (from newer to older events) starting at to_bound. - # This ensures we fill the `limit` with the newest events first, - from_key=to_bound, - to_key=from_bound, - direction=Direction.BACKWARDS, - # We add one so we can determine if there are enough events to saturate - # the limit or not (see `limited`) - limit=room_sync_config.timeline_limit + 1, - ) - - # We want to return the events in ascending order (the last event is the - # most recent). - timeline_events.reverse() - - # Determine our `limited` status based on the timeline. We do this before - # filtering the events so we can accurately determine if there is more to - # paginate even if we filter out some/all events. 
- if len(timeline_events) > room_sync_config.timeline_limit: - limited = True - # Get rid of that extra "+ 1" event because we only used it to determine - # if we hit the limit or not - timeline_events = timeline_events[-room_sync_config.timeline_limit :] - assert timeline_events[0].internal_metadata.stream_ordering - new_room_key = RoomStreamToken( - stream=timeline_events[0].internal_metadata.stream_ordering - 1 - ) - - # Make sure we don't expose any events that the client shouldn't see - timeline_events = await filter_events_for_client( - self.storage_controllers, - user.to_string(), - timeline_events, - is_peeking=room_membership_for_user_at_to_token.membership - != Membership.JOIN, - filter_send_to_client=True, - ) - # TODO: Filter out `EventTypes.CallInvite` in public rooms, - # see https://github.com/element-hq/synapse/issues/17359 - - # TODO: Handle timeline gaps (`get_timeline_gaps()`) - - # Determine how many "live" events we have (events within the given token range). - # - # This is mostly useful to determine whether a given @mention event should - # make a noise or not. 
Clients cannot rely solely on the absence of - # `initial: true` to determine live events because if a room not in the - # sliding window bumps into the window because of an @mention it will have - # `initial: true` yet contain a single live event (with potentially other - # old events in the timeline) - num_live = 0 - if from_token is not None: - for timeline_event in reversed(timeline_events): - # This fields should be present for all persisted events - assert timeline_event.internal_metadata.stream_ordering is not None - assert timeline_event.internal_metadata.instance_name is not None - - persisted_position = PersistedEventPosition( - instance_name=timeline_event.internal_metadata.instance_name, - stream=timeline_event.internal_metadata.stream_ordering, - ) - if persisted_position.persisted_after( - from_token.stream_token.room_key - ): - num_live += 1 - else: - # Since we're iterating over the timeline events in - # reverse-chronological order, we can break once we hit an event - # that's not live. In the future, we could potentially optimize - # this more with a binary search (bisect). - break - - # If the timeline is `limited=True`, the client does not have all events - # necessary to calculate aggregations themselves. - if limited: - bundled_aggregations = ( - await self.relations_handler.get_bundled_aggregations( - timeline_events, user.to_string() - ) - ) - - # Update the `prev_batch_token` to point to the position that allows us to - # keep paginating backwards from the oldest event we return in the timeline. - prev_batch_token = prev_batch_token.copy_and_replace( - StreamKeyType.ROOM, new_room_key - ) - - # Figure out any stripped state events for invite/knocks. This allows the - # potential joiner to identify the room. - stripped_state: List[JsonDict] = [] - if room_membership_for_user_at_to_token.membership in ( - Membership.INVITE, - Membership.KNOCK, - ): - # This should never happen. 
If someone is invited/knocked on room, then - # there should be an event for it. - assert room_membership_for_user_at_to_token.event_id is not None - - invite_or_knock_event = await self.store.get_event( - room_membership_for_user_at_to_token.event_id - ) - - stripped_state = [] - if invite_or_knock_event.membership == Membership.INVITE: - stripped_state.extend( - invite_or_knock_event.unsigned.get("invite_room_state", []) - ) - elif invite_or_knock_event.membership == Membership.KNOCK: - stripped_state.extend( - invite_or_knock_event.unsigned.get("knock_room_state", []) - ) - - stripped_state.append(strip_event(invite_or_knock_event)) - - # TODO: Handle state resets. For example, if we see - # `room_membership_for_user_at_to_token.event_id=None and - # room_membership_for_user_at_to_token.membership is not None`, we should - # indicate to the client that a state reset happened. Perhaps we should indicate - # this by setting `initial: True` and empty `required_state`. - - # Check whether the room has a name set - name_state_ids = await self.get_current_state_ids_at( - room_id=room_id, - room_membership_for_user_at_to_token=room_membership_for_user_at_to_token, - state_filter=StateFilter.from_types([(EventTypes.Name, "")]), - to_token=to_token, - ) - name_event_id = name_state_ids.get((EventTypes.Name, "")) - - room_membership_summary: Mapping[str, MemberSummary] - empty_membership_summary = MemberSummary([], 0) - if room_membership_for_user_at_to_token.membership in ( - Membership.LEAVE, - Membership.BAN, - ): - # TODO: Figure out how to get the membership summary for left/banned rooms - room_membership_summary = {} - else: - room_membership_summary = await self.store.get_room_summary(room_id) - # TODO: Reverse/rewind back to the `to_token` - - # `heroes` are required if the room name is not set. 
- # - # Note: When you're the first one on your server to be invited to a new room - # over federation, we only have access to some stripped state in - # `event.unsigned.invite_room_state` which currently doesn't include `heroes`, - # see https://github.com/matrix-org/matrix-spec/issues/380. This means that - # clients won't be able to calculate the room name when necessary and just a - # pitfall we have to deal with until that spec issue is resolved. - hero_user_ids: List[str] = [] - # TODO: Should we also check for `EventTypes.CanonicalAlias` - # (`m.room.canonical_alias`) as a fallback for the room name? see - # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1671260153 - if name_event_id is None: - hero_user_ids = extract_heroes_from_room_summary( - room_membership_summary, me=user.to_string() - ) - - # Fetch the `required_state` for the room - # - # No `required_state` for invite/knock rooms (just `stripped_state`) - # - # FIXME: It would be nice to make the `rooms` response more uniform regardless - # of membership. Currently, we have to make this optional because - # `invite`/`knock` rooms only have `stripped_state`. See - # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932 - # - # Calculate the `StateFilter` based on the `required_state` for the room - required_state_filter = StateFilter.none() - if room_membership_for_user_at_to_token.membership not in ( - Membership.INVITE, - Membership.KNOCK, - ): - # If we have a double wildcard ("*", "*") in the `required_state`, we need - # to fetch all state for the room - # - # Note: MSC3575 describes different behavior to how we're handling things - # here but since it's not wrong to return more state than requested - # (`required_state` is just the minimum requested), it doesn't matter if we - # include more than client wanted. 
This complexity is also under scrutiny, - # see - # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1185109050 - # - # > One unique exception is when you request all state events via ["*", "*"]. When used, - # > all state events are returned by default, and additional entries FILTER OUT the returned set - # > of state events. These additional entries cannot use '*' themselves. - # > For example, ["*", "*"], ["m.room.member", "@alice:example.com"] will _exclude_ every m.room.member - # > event _except_ for @alice:example.com, and include every other state event. - # > In addition, ["*", "*"], ["m.space.child", "*"] is an error, the m.space.child filter is not - # > required as it would have been returned anyway. - # > - # > -- MSC3575 (https://github.com/matrix-org/matrix-spec-proposals/pull/3575) - if StateValues.WILDCARD in room_sync_config.required_state_map.get( - StateValues.WILDCARD, set() - ): - set_tag( - SynapseTags.FUNC_ARG_PREFIX + "required_state_wildcard", - True, - ) - required_state_filter = StateFilter.all() - # TODO: `StateFilter` currently doesn't support wildcard event types. We're - # currently working around this by returning all state to the client but it - # would be nice to fetch less from the database and return just what the - # client wanted. 
- elif ( - room_sync_config.required_state_map.get(StateValues.WILDCARD) - is not None - ): - set_tag( - SynapseTags.FUNC_ARG_PREFIX + "required_state_wildcard_event_type", - True, - ) - required_state_filter = StateFilter.all() - else: - required_state_types: List[Tuple[str, Optional[str]]] = [] - for ( - state_type, - state_key_set, - ) in room_sync_config.required_state_map.items(): - num_wild_state_keys = 0 - lazy_load_room_members = False - num_others = 0 - for state_key in state_key_set: - if state_key == StateValues.WILDCARD: - num_wild_state_keys += 1 - # `None` is a wildcard in the `StateFilter` - required_state_types.append((state_type, None)) - # We need to fetch all relevant people when we're lazy-loading membership - elif ( - state_type == EventTypes.Member - and state_key == StateValues.LAZY - ): - lazy_load_room_members = True - # Everyone in the timeline is relevant - timeline_membership: Set[str] = set() - if timeline_events is not None: - for timeline_event in timeline_events: - timeline_membership.add(timeline_event.sender) - - for user_id in timeline_membership: - required_state_types.append( - (EventTypes.Member, user_id) - ) - - # FIXME: We probably also care about invite, ban, kick, targets, etc - # but the spec only mentions "senders". 
- elif state_key == StateValues.ME: - num_others += 1 - required_state_types.append((state_type, user.to_string())) - else: - num_others += 1 - required_state_types.append((state_type, state_key)) - - set_tag( - SynapseTags.FUNC_ARG_PREFIX - + "required_state_wildcard_state_key_count", - num_wild_state_keys, - ) - set_tag( - SynapseTags.FUNC_ARG_PREFIX + "required_state_lazy", - lazy_load_room_members, - ) - set_tag( - SynapseTags.FUNC_ARG_PREFIX + "required_state_other_count", - num_others, - ) - - required_state_filter = StateFilter.from_types(required_state_types) - - # We need this base set of info for the response so let's just fetch it along - # with the `required_state` for the room - meta_room_state = [(EventTypes.Name, ""), (EventTypes.RoomAvatar, "")] + [ - (EventTypes.Member, hero_user_id) for hero_user_id in hero_user_ids - ] - state_filter = StateFilter.all() - if required_state_filter != StateFilter.all(): - state_filter = StateFilter( - types=StateFilter.from_types( - chain(meta_room_state, required_state_filter.to_types()) - ).types, - include_others=required_state_filter.include_others, - ) - - # We can return all of the state that was requested if this was the first - # time we've sent the room down this connection. - room_state: StateMap[EventBase] = {} - if initial: - room_state = await self.get_current_state_at( - room_id=room_id, - room_membership_for_user_at_to_token=room_membership_for_user_at_to_token, - state_filter=state_filter, - to_token=to_token, - ) - else: - assert from_bound is not None - - # TODO: Limit the number of state events we're about to send down - # the room, if its too many we should change this to an - # `initial=True`? 
- deltas = await self.store.get_current_state_deltas_for_room( - room_id=room_id, - from_token=from_bound, - to_token=to_token.room_key, - ) - # TODO: Filter room state before fetching events - # TODO: Handle state resets where event_id is None - events = await self.store.get_events( - [d.event_id for d in deltas if d.event_id] - ) - room_state = {(s.type, s.state_key): s for s in events.values()} - - required_room_state: StateMap[EventBase] = {} - if required_state_filter != StateFilter.none(): - required_room_state = required_state_filter.filter_state(room_state) - - # Find the room name and avatar from the state - room_name: Optional[str] = None - # TODO: Should we also check for `EventTypes.CanonicalAlias` - # (`m.room.canonical_alias`) as a fallback for the room name? see - # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1671260153 - name_event = room_state.get((EventTypes.Name, "")) - if name_event is not None: - room_name = name_event.content.get("name") - - room_avatar: Optional[str] = None - avatar_event = room_state.get((EventTypes.RoomAvatar, "")) - if avatar_event is not None: - room_avatar = avatar_event.content.get("url") - - # Assemble heroes: extract the info from the state we just fetched - heroes: List[SlidingSyncResult.RoomResult.StrippedHero] = [] - for hero_user_id in hero_user_ids: - member_event = room_state.get((EventTypes.Member, hero_user_id)) - if member_event is not None: - heroes.append( - SlidingSyncResult.RoomResult.StrippedHero( - user_id=hero_user_id, - display_name=member_event.content.get("displayname"), - avatar_url=member_event.content.get("avatar_url"), - ) - ) - - # Figure out the last bump event in the room - last_bump_event_result = ( - await self.store.get_last_event_pos_in_room_before_stream_ordering( - room_id, to_token.room_key, event_types=DEFAULT_BUMP_EVENT_TYPES - ) - ) - - # By default, just choose the membership event position - bump_stamp = 
room_membership_for_user_at_to_token.event_pos.stream - # But if we found a bump event, use that instead - if last_bump_event_result is not None: - _, new_bump_event_pos = last_bump_event_result - - # If we've just joined a remote room, then the last bump event may - # have been backfilled (and so have a negative stream ordering). - # These negative stream orderings can't sensibly be compared, so - # instead we use the membership event position. - if new_bump_event_pos.stream > 0: - bump_stamp = new_bump_event_pos.stream - - set_tag(SynapseTags.RESULT_PREFIX + "initial", initial) - - return SlidingSyncResult.RoomResult( - name=room_name, - avatar=room_avatar, - heroes=heroes, - is_dm=room_membership_for_user_at_to_token.is_dm, - initial=initial, - required_state=list(required_room_state.values()), - timeline_events=timeline_events, - bundled_aggregations=bundled_aggregations, - stripped_state=stripped_state, - prev_batch=prev_batch_token, - limited=limited, - num_live=num_live, - bump_stamp=bump_stamp, - joined_count=room_membership_summary.get( - Membership.JOIN, empty_membership_summary - ).count, - invited_count=room_membership_summary.get( - Membership.INVITE, empty_membership_summary - ).count, - # TODO: These are just dummy values. We could potentially just remove these - # since notifications can only really be done correctly on the client anyway - # (encrypted rooms). - notification_count=0, - highlight_count=0, - ) - - @trace - async def get_extensions_response( - self, - sync_config: SlidingSyncConfig, - actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList], - actual_room_ids: Set[str], - actual_room_response_map: Dict[str, SlidingSyncResult.RoomResult], - to_token: StreamToken, - from_token: Optional[SlidingSyncStreamToken], - ) -> SlidingSyncResult.Extensions: - """Handle extension requests. - - Args: - sync_config: Sync configuration - actual_lists: Sliding window API. A map of list key to list results in the - Sliding Sync response. 
- actual_room_ids: The actual room IDs in the the Sliding Sync response. - actual_room_response_map: A map of room ID to room results in the the - Sliding Sync response. - to_token: The point in the stream to sync up to. - from_token: The point in the stream to sync from. - """ - - if sync_config.extensions is None: - return SlidingSyncResult.Extensions() - - to_device_response = None - if sync_config.extensions.to_device is not None: - to_device_response = await self.get_to_device_extension_response( - sync_config=sync_config, - to_device_request=sync_config.extensions.to_device, - to_token=to_token, - ) - - e2ee_response = None - if sync_config.extensions.e2ee is not None: - e2ee_response = await self.get_e2ee_extension_response( - sync_config=sync_config, - e2ee_request=sync_config.extensions.e2ee, - to_token=to_token, - from_token=from_token, - ) - - account_data_response = None - if sync_config.extensions.account_data is not None: - account_data_response = await self.get_account_data_extension_response( - sync_config=sync_config, - actual_lists=actual_lists, - actual_room_ids=actual_room_ids, - account_data_request=sync_config.extensions.account_data, - to_token=to_token, - from_token=from_token, - ) - - receipts_response = None - if sync_config.extensions.receipts is not None: - receipts_response = await self.get_receipts_extension_response( - sync_config=sync_config, - actual_lists=actual_lists, - actual_room_ids=actual_room_ids, - actual_room_response_map=actual_room_response_map, - receipts_request=sync_config.extensions.receipts, - to_token=to_token, - from_token=from_token, - ) - - typing_response = None - if sync_config.extensions.typing is not None: - typing_response = await self.get_typing_extension_response( - sync_config=sync_config, - actual_lists=actual_lists, - actual_room_ids=actual_room_ids, - actual_room_response_map=actual_room_response_map, - typing_request=sync_config.extensions.typing, - to_token=to_token, - from_token=from_token, - ) - - 
return SlidingSyncResult.Extensions( - to_device=to_device_response, - e2ee=e2ee_response, - account_data=account_data_response, - receipts=receipts_response, - typing=typing_response, - ) - - def find_relevant_room_ids_for_extension( - self, - requested_lists: Optional[List[str]], - requested_room_ids: Optional[List[str]], - actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList], - actual_room_ids: Set[str], - ) -> Set[str]: - """ - Handle the reserved `lists`/`rooms` keys for extensions. Extensions should only - return results for rooms in the Sliding Sync response. This matches up the - requested rooms/lists with the actual lists/rooms in the Sliding Sync response. - - {"lists": []} // Do not process any lists. - {"lists": ["rooms", "dms"]} // Process only a subset of lists. - {"lists": ["*"]} // Process all lists defined in the Sliding Window API. (This is the default.) - - {"rooms": []} // Do not process any specific rooms. - {"rooms": ["!a:b", "!c:d"]} // Process only a subset of room subscriptions. - {"rooms": ["*"]} // Process all room subscriptions defined in the Room Subscription API. (This is the default.) - - Args: - requested_lists: The `lists` from the extension request. - requested_room_ids: The `rooms` from the extension request. - actual_lists: The actual lists from the Sliding Sync response. - actual_room_ids: The actual room subscriptions from the Sliding Sync request. - """ - - # We only want to include account data for rooms that are already in the sliding - # sync response AND that were requested in the account data request. 
- relevant_room_ids: Set[str] = set() - - # See what rooms from the room subscriptions we should get account data for - if requested_room_ids is not None: - for room_id in requested_room_ids: - # A wildcard means we process all rooms from the room subscriptions - if room_id == "*": - relevant_room_ids.update(actual_room_ids) - break - - if room_id in actual_room_ids: - relevant_room_ids.add(room_id) - - # See what rooms from the sliding window lists we should get account data for - if requested_lists is not None: - for list_key in requested_lists: - # Just some typing because we share the variable name in multiple places - actual_list: Optional[SlidingSyncResult.SlidingWindowList] = None - - # A wildcard means we process rooms from all lists - if list_key == "*": - for actual_list in actual_lists.values(): - # We only expect a single SYNC operation for any list - assert len(actual_list.ops) == 1 - sync_op = actual_list.ops[0] - assert sync_op.op == OperationType.SYNC - - relevant_room_ids.update(sync_op.room_ids) - - break - - actual_list = actual_lists.get(list_key) - if actual_list is not None: - # We only expect a single SYNC operation for any list - assert len(actual_list.ops) == 1 - sync_op = actual_list.ops[0] - assert sync_op.op == OperationType.SYNC - - relevant_room_ids.update(sync_op.room_ids) - - return relevant_room_ids - - @trace - async def get_to_device_extension_response( - self, - sync_config: SlidingSyncConfig, - to_device_request: SlidingSyncConfig.Extensions.ToDeviceExtension, - to_token: StreamToken, - ) -> Optional[SlidingSyncResult.Extensions.ToDeviceExtension]: - """Handle to-device extension (MSC3885) - - Args: - sync_config: Sync configuration - to_device_request: The to-device extension from the request - to_token: The point in the stream to sync up to. 
- """ - user_id = sync_config.user.to_string() - device_id = sync_config.requester.device_id - - # Skip if the extension is not enabled - if not to_device_request.enabled: - return None - - # Check that this request has a valid device ID (not all requests have - # to belong to a device, and so device_id is None) - if device_id is None: - return SlidingSyncResult.Extensions.ToDeviceExtension( - next_batch=f"{to_token.to_device_key}", - events=[], - ) - - since_stream_id = 0 - if to_device_request.since is not None: - # We've already validated this is an int. - since_stream_id = int(to_device_request.since) - - if to_token.to_device_key < since_stream_id: - # The since token is ahead of our current token, so we return an - # empty response. - logger.warning( - "Got to-device.since from the future. since token: %r is ahead of our current to_device stream position: %r", - since_stream_id, - to_token.to_device_key, - ) - return SlidingSyncResult.Extensions.ToDeviceExtension( - next_batch=to_device_request.since, - events=[], - ) - - # Delete everything before the given since token, as we know the - # device must have received them. 
- deleted = await self.store.delete_messages_for_device( - user_id=user_id, - device_id=device_id, - up_to_stream_id=since_stream_id, - ) - - logger.debug( - "Deleted %d to-device messages up to %d for %s", - deleted, - since_stream_id, - user_id, - ) - - messages, stream_id = await self.store.get_messages_for_device( - user_id=user_id, - device_id=device_id, - from_stream_id=since_stream_id, - to_stream_id=to_token.to_device_key, - limit=min(to_device_request.limit, 100), # Limit to at most 100 events - ) - - return SlidingSyncResult.Extensions.ToDeviceExtension( - next_batch=f"{stream_id}", - events=messages, - ) - - @trace - async def get_e2ee_extension_response( - self, - sync_config: SlidingSyncConfig, - e2ee_request: SlidingSyncConfig.Extensions.E2eeExtension, - to_token: StreamToken, - from_token: Optional[SlidingSyncStreamToken], - ) -> Optional[SlidingSyncResult.Extensions.E2eeExtension]: - """Handle E2EE device extension (MSC3884) - - Args: - sync_config: Sync configuration - e2ee_request: The e2ee extension from the request - to_token: The point in the stream to sync up to. - from_token: The point in the stream to sync from. 
- """ - user_id = sync_config.user.to_string() - device_id = sync_config.requester.device_id - - # Skip if the extension is not enabled - if not e2ee_request.enabled: - return None - - device_list_updates: Optional[DeviceListUpdates] = None - if from_token is not None: - # TODO: This should take into account the `from_token` and `to_token` - device_list_updates = await self.device_handler.get_user_ids_changed( - user_id=user_id, - from_token=from_token.stream_token, - ) - - device_one_time_keys_count: Mapping[str, int] = {} - device_unused_fallback_key_types: Sequence[str] = [] - if device_id: - # TODO: We should have a way to let clients differentiate between the states of: - # * no change in OTK count since the provided since token - # * the server has zero OTKs left for this device - # Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298 - device_one_time_keys_count = await self.store.count_e2e_one_time_keys( - user_id, device_id - ) - device_unused_fallback_key_types = ( - await self.store.get_e2e_unused_fallback_key_types(user_id, device_id) - ) - - return SlidingSyncResult.Extensions.E2eeExtension( - device_list_updates=device_list_updates, - device_one_time_keys_count=device_one_time_keys_count, - device_unused_fallback_key_types=device_unused_fallback_key_types, - ) - - @trace - async def get_account_data_extension_response( - self, - sync_config: SlidingSyncConfig, - actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList], - actual_room_ids: Set[str], - account_data_request: SlidingSyncConfig.Extensions.AccountDataExtension, - to_token: StreamToken, - from_token: Optional[SlidingSyncStreamToken], - ) -> Optional[SlidingSyncResult.Extensions.AccountDataExtension]: - """Handle Account Data extension (MSC3959) - - Args: - sync_config: Sync configuration - actual_lists: Sliding window API. A map of list key to list results in the - Sliding Sync response. - actual_room_ids: The actual room IDs in the the Sliding Sync response. 
- account_data_request: The account_data extension from the request - to_token: The point in the stream to sync up to. - from_token: The point in the stream to sync from. - """ - user_id = sync_config.user.to_string() - - # Skip if the extension is not enabled - if not account_data_request.enabled: - return None - - global_account_data_map: Mapping[str, JsonMapping] = {} - if from_token is not None: - # TODO: This should take into account the `from_token` and `to_token` - global_account_data_map = ( - await self.store.get_updated_global_account_data_for_user( - user_id, from_token.stream_token.account_data_key - ) - ) - - have_push_rules_changed = await self.store.have_push_rules_changed_for_user( - user_id, from_token.stream_token.push_rules_key - ) - if have_push_rules_changed: - global_account_data_map = dict(global_account_data_map) - # TODO: This should take into account the `from_token` and `to_token` - global_account_data_map[AccountDataTypes.PUSH_RULES] = ( - await self.push_rules_handler.push_rules_for_user(sync_config.user) - ) - else: - # TODO: This should take into account the `to_token` - all_global_account_data = await self.store.get_global_account_data_for_user( - user_id - ) - - global_account_data_map = dict(all_global_account_data) - # TODO: This should take into account the `to_token` - global_account_data_map[AccountDataTypes.PUSH_RULES] = ( - await self.push_rules_handler.push_rules_for_user(sync_config.user) - ) - - # Fetch room account data - account_data_by_room_map: Mapping[str, Mapping[str, JsonMapping]] = {} - relevant_room_ids = self.find_relevant_room_ids_for_extension( - requested_lists=account_data_request.lists, - requested_room_ids=account_data_request.rooms, - actual_lists=actual_lists, - actual_room_ids=actual_room_ids, - ) - if len(relevant_room_ids) > 0: - if from_token is not None: - # TODO: This should take into account the `from_token` and `to_token` - account_data_by_room_map = ( - await 
self.store.get_updated_room_account_data_for_user( - user_id, from_token.stream_token.account_data_key - ) - ) - else: - # TODO: This should take into account the `to_token` - account_data_by_room_map = ( - await self.store.get_room_account_data_for_user(user_id) - ) - - # Filter down to the relevant rooms - account_data_by_room_map = { - room_id: account_data_map - for room_id, account_data_map in account_data_by_room_map.items() - if room_id in relevant_room_ids - } - - return SlidingSyncResult.Extensions.AccountDataExtension( - global_account_data_map=global_account_data_map, - account_data_by_room_map=account_data_by_room_map, - ) - - async def get_receipts_extension_response( - self, - sync_config: SlidingSyncConfig, - actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList], - actual_room_ids: Set[str], - actual_room_response_map: Dict[str, SlidingSyncResult.RoomResult], - receipts_request: SlidingSyncConfig.Extensions.ReceiptsExtension, - to_token: StreamToken, - from_token: Optional[SlidingSyncStreamToken], - ) -> Optional[SlidingSyncResult.Extensions.ReceiptsExtension]: - """Handle Receipts extension (MSC3960) - - Args: - sync_config: Sync configuration - actual_lists: Sliding window API. A map of list key to list results in the - Sliding Sync response. - actual_room_ids: The actual room IDs in the the Sliding Sync response. - actual_room_response_map: A map of room ID to room results in the the - Sliding Sync response. - account_data_request: The account_data extension from the request - to_token: The point in the stream to sync up to. - from_token: The point in the stream to sync from. 
- """ - # Skip if the extension is not enabled - if not receipts_request.enabled: - return None - - relevant_room_ids = self.find_relevant_room_ids_for_extension( - requested_lists=receipts_request.lists, - requested_room_ids=receipts_request.rooms, - actual_lists=actual_lists, - actual_room_ids=actual_room_ids, - ) - - room_id_to_receipt_map: Dict[str, JsonMapping] = {} - if len(relevant_room_ids) > 0: - # TODO: Take connection tracking into account so that when a room comes back - # into range we can send the receipts that were missed. - receipt_source = self.event_sources.sources.receipt - receipts, _ = await receipt_source.get_new_events( - user=sync_config.user, - from_key=( - from_token.stream_token.receipt_key - if from_token - else MultiWriterStreamToken(stream=0) - ), - to_key=to_token.receipt_key, - # This is a dummy value and isn't used in the function - limit=0, - room_ids=relevant_room_ids, - is_guest=False, - ) - - for receipt in receipts: - # These fields should exist for every receipt - room_id = receipt["room_id"] - type = receipt["type"] - content = receipt["content"] - - # For `inital: True` rooms, we only want to include receipts for events - # in the timeline. - room_result = actual_room_response_map.get(room_id) - if room_result is not None: - if room_result.initial: - # TODO: In the future, it would be good to fetch less receipts - # out of the database in the first place but we would need to - # add a new `event_id` index to `receipts_linearized`. 
- relevant_event_ids = [ - event.event_id for event in room_result.timeline_events - ] - - assert isinstance(content, dict) - content = { - event_id: content_value - for event_id, content_value in content.items() - if event_id in relevant_event_ids - } - - room_id_to_receipt_map[room_id] = {"type": type, "content": content} - - return SlidingSyncResult.Extensions.ReceiptsExtension( - room_id_to_receipt_map=room_id_to_receipt_map, - ) - - async def get_typing_extension_response( - self, - sync_config: SlidingSyncConfig, - actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList], - actual_room_ids: Set[str], - actual_room_response_map: Dict[str, SlidingSyncResult.RoomResult], - typing_request: SlidingSyncConfig.Extensions.TypingExtension, - to_token: StreamToken, - from_token: Optional[SlidingSyncStreamToken], - ) -> Optional[SlidingSyncResult.Extensions.TypingExtension]: - """Handle Typing Notification extension (MSC3961) - - Args: - sync_config: Sync configuration - actual_lists: Sliding window API. A map of list key to list results in the - Sliding Sync response. - actual_room_ids: The actual room IDs in the the Sliding Sync response. - actual_room_response_map: A map of room ID to room results in the the - Sliding Sync response. - account_data_request: The account_data extension from the request - to_token: The point in the stream to sync up to. - from_token: The point in the stream to sync from. 
- """ - # Skip if the extension is not enabled - if not typing_request.enabled: - return None - - relevant_room_ids = self.find_relevant_room_ids_for_extension( - requested_lists=typing_request.lists, - requested_room_ids=typing_request.rooms, - actual_lists=actual_lists, - actual_room_ids=actual_room_ids, - ) - - room_id_to_typing_map: Dict[str, JsonMapping] = {} - if len(relevant_room_ids) > 0: - # Note: We don't need to take connection tracking into account for typing - # notifications because they'll get anything still relevant and hasn't timed - # out when the room comes into range. We consider the gap where the room - # fell out of range, as long enough for any typing notifications to have - # timed out (it's not worth the 30 seconds of data we may have missed). - typing_source = self.event_sources.sources.typing - typing_notifications, _ = await typing_source.get_new_events( - user=sync_config.user, - from_key=(from_token.stream_token.typing_key if from_token else 0), - to_key=to_token.typing_key, - # This is a dummy value and isn't used in the function - limit=0, - room_ids=relevant_room_ids, - is_guest=False, - ) - - for typing_notification in typing_notifications: - # These fields should exist for every typing notification - room_id = typing_notification["room_id"] - type = typing_notification["type"] - content = typing_notification["content"] - - room_id_to_typing_map[room_id] = {"type": type, "content": content} - - return SlidingSyncResult.Extensions.TypingExtension( - room_id_to_typing_map=room_id_to_typing_map, - ) - - -class HaveSentRoomFlag(Enum): - """Flag for whether we have sent the room down a sliding sync connection. - - The valid state changes here are: - NEVER -> LIVE - LIVE -> PREVIOUSLY - PREVIOUSLY -> LIVE - """ - - # The room has never been sent down (or we have forgotten we have sent it - # down). - NEVER = 1 - - # We have previously sent the room down, but there are updates that we - # haven't sent down. 
- PREVIOUSLY = 2 - - # We have sent the room down and the client has received all updates. - LIVE = 3 - - -@attr.s(auto_attribs=True, slots=True, frozen=True) -class HaveSentRoom: - """Whether we have sent the room down a sliding sync connection. - - Attributes: - status: Flag of if we have or haven't sent down the room - last_token: If the flag is `PREVIOUSLY` then this is non-null and - contains the last stream token of the last updates we sent down - the room, i.e. we still need to send everything since then to the - client. - """ - - status: HaveSentRoomFlag - last_token: Optional[RoomStreamToken] - - @staticmethod - def previously(last_token: RoomStreamToken) -> "HaveSentRoom": - """Constructor for `PREVIOUSLY` flag.""" - return HaveSentRoom(HaveSentRoomFlag.PREVIOUSLY, last_token) - - -HAVE_SENT_ROOM_NEVER = HaveSentRoom(HaveSentRoomFlag.NEVER, None) -HAVE_SENT_ROOM_LIVE = HaveSentRoom(HaveSentRoomFlag.LIVE, None) - - -@attr.s(auto_attribs=True) -class SlidingSyncConnectionStore: - """In-memory store of per-connection state, including what rooms we have - previously sent down a sliding sync connection. - - Note: This is NOT safe to run in a worker setup because connection positions will - point to different sets of rooms on different workers. e.g. for the same connection, - a connection position of 5 might have totally different states on worker A and - worker B. - - One complication that we need to deal with here is needing to handle requests being - resent, i.e. if we sent down a room in a response that the client received, we must - consider the room *not* sent when we get the request again. - - This is handled by using an integer "token", which is returned to the client - as part of the sync token. For each connection we store a mapping from - tokens to the room states, and create a new entry when we send down new - rooms. 
- - Note that for any given sliding sync connection we will only store a maximum - of two different tokens: the previous token from the request and a new token - sent in the response. When we receive a request with a given token, we then - clear out all other entries with a different token. - - Attributes: - _connections: Mapping from `(user_id, conn_id)` to mapping of `token` - to mapping of room ID to `HaveSentRoom`. - """ - - # `(user_id, conn_id)` -> `token` -> `room_id` -> `HaveSentRoom` - _connections: Dict[Tuple[str, str], Dict[int, Dict[str, HaveSentRoom]]] = ( - attr.Factory(dict) - ) - - async def is_valid_token( - self, sync_config: SlidingSyncConfig, connection_token: int - ) -> bool: - """Return whether the connection token is valid/recognized""" - if connection_token == 0: - return True - - conn_key = self._get_connection_key(sync_config) - return connection_token in self._connections.get(conn_key, {}) - - async def have_sent_room( - self, sync_config: SlidingSyncConfig, connection_token: int, room_id: str - ) -> HaveSentRoom: - """For the given user_id/conn_id/token, return whether we have - previously sent the room down - """ - - conn_key = self._get_connection_key(sync_config) - sync_statuses = self._connections.setdefault(conn_key, {}) - room_status = sync_statuses.get(connection_token, {}).get( - room_id, HAVE_SENT_ROOM_NEVER - ) - - return room_status - - @trace - async def record_rooms( - self, - sync_config: SlidingSyncConfig, - from_token: Optional[SlidingSyncStreamToken], - *, - sent_room_ids: StrCollection, - unsent_room_ids: StrCollection, - ) -> int: - """Record which rooms we have/haven't sent down in a new response - - Attributes: - sync_config - from_token: The since token from the request, if any - sent_room_ids: The set of room IDs that we have sent down as - part of this request (only needs to be ones we didn't - previously sent down). 
- unsent_room_ids: The set of room IDs that have had updates - since the `from_token`, but which were not included in - this request - """ - prev_connection_token = 0 - if from_token is not None: - prev_connection_token = from_token.connection_position - - # If there are no changes then this is a noop. - if not sent_room_ids and not unsent_room_ids: - return prev_connection_token - - conn_key = self._get_connection_key(sync_config) - sync_statuses = self._connections.setdefault(conn_key, {}) - - # Generate a new token, removing any existing entries in that token - # (which can happen if requests get resent). - new_store_token = prev_connection_token + 1 - sync_statuses.pop(new_store_token, None) - - # Copy over and update the room mappings. - new_room_statuses = dict(sync_statuses.get(prev_connection_token, {})) - - # Whether we have updated the `new_room_statuses`, if we don't by the - # end we can treat this as a noop. - have_updated = False - for room_id in sent_room_ids: - new_room_statuses[room_id] = HAVE_SENT_ROOM_LIVE - have_updated = True - - # Whether we add/update the entries for unsent rooms depends on the - # existing entry: - # - LIVE: We have previously sent down everything up to - # `last_room_token, so we update the entry to be `PREVIOUSLY` with - # `last_room_token`. - # - PREVIOUSLY: We have previously sent down everything up to *a* - # given token, so we don't need to update the entry. - # - NEVER: We have never previously sent down the room, and we haven't - # sent anything down this time either so we leave it as NEVER. - - # Work out the new state for unsent rooms that were `LIVE`. 
- if from_token: - new_unsent_state = HaveSentRoom.previously(from_token.stream_token.room_key) - else: - new_unsent_state = HAVE_SENT_ROOM_NEVER - - for room_id in unsent_room_ids: - prev_state = new_room_statuses.get(room_id) - if prev_state is not None and prev_state.status == HaveSentRoomFlag.LIVE: - new_room_statuses[room_id] = new_unsent_state - have_updated = True - - if not have_updated: - return prev_connection_token - - sync_statuses[new_store_token] = new_room_statuses - - return new_store_token - - @trace - async def mark_token_seen( - self, - sync_config: SlidingSyncConfig, - from_token: Optional[SlidingSyncStreamToken], - ) -> None: - """We have received a request with the given token, so we can clear out - any other tokens associated with the connection. - - If there is no from token then we have started afresh, and so we delete - all tokens associated with the device. - """ - # Clear out any tokens for the connection that doesn't match the one - # from the request. - - conn_key = self._get_connection_key(sync_config) - sync_statuses = self._connections.pop(conn_key, {}) - if from_token is None: - return - - sync_statuses = { - connection_token: room_statuses - for connection_token, room_statuses in sync_statuses.items() - if connection_token == from_token.connection_position - } - if sync_statuses: - self._connections[conn_key] = sync_statuses - - @staticmethod - def _get_connection_key(sync_config: SlidingSyncConfig) -> Tuple[str, str]: - """Return a unique identifier for this connection. - - The first part is simply the user ID. - - The second part is generally a combination of device ID and conn_id. - However, both these two are optional (e.g. puppet access tokens don't - have device IDs), so this handles those edge cases. - - We use this over the raw `conn_id` to avoid clashes between different - clients that use the same `conn_id`. 
Imagine a user uses a web client - that uses `conn_id: main_sync_loop` and an Android client that also has - a `conn_id: main_sync_loop`. - """ - - user_id = sync_config.user.to_string() - - # Only one sliding sync connection is allowed per given conn_id (empty - # or not). - conn_id = sync_config.conn_id or "" - - if sync_config.requester.device_id: - return (user_id, f"D/{sync_config.requester.device_id}/{conn_id}") - - if sync_config.requester.access_token_id: - # If we don't have a device, then the access token ID should be a - # stable ID. - return (user_id, f"A/{sync_config.requester.access_token_id}/{conn_id}") - - # If we have neither then its likely an AS or some weird token. Either - # way we can just fail here. - raise Exception("Cannot use sliding sync with access token type") diff --git a/synapse/handlers/sliding_sync/__init__.py b/synapse/handlers/sliding_sync/__init__.py new file mode 100644
index 0000000000..cb56eb53fc --- /dev/null +++ b/synapse/handlers/sliding_sync/__init__.py
@@ -0,0 +1,1691 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2023 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# + +import itertools +import logging +from itertools import chain +from typing import TYPE_CHECKING, AbstractSet, Dict, List, Mapping, Optional, Set, Tuple + +from prometheus_client import Histogram +from typing_extensions import assert_never + +from synapse.api.constants import Direction, EventTypes, Membership +from synapse.events import EventBase +from synapse.events.utils import strip_event +from synapse.handlers.relations import BundledAggregations +from synapse.handlers.sliding_sync.extensions import SlidingSyncExtensionHandler +from synapse.handlers.sliding_sync.room_lists import ( + RoomsForUserType, + SlidingSyncRoomLists, +) +from synapse.handlers.sliding_sync.store import SlidingSyncConnectionStore +from synapse.logging.opentracing import ( + SynapseTags, + log_kv, + set_tag, + start_active_span, + tag_args, + trace, +) +from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary +from synapse.storage.databases.main.state_deltas import StateDelta +from synapse.storage.databases.main.stream import PaginateFunction +from synapse.storage.roommember import ( + MemberSummary, +) +from synapse.types import ( + JsonDict, + MutableStateMap, + PersistedEventPosition, + Requester, + RoomStreamToken, + SlidingSyncStreamToken, + StateMap, + StrCollection, + StreamKeyType, + StreamToken, +) +from synapse.types.handlers import SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES +from synapse.types.handlers.sliding_sync import ( + 
HaveSentRoomFlag, + MutablePerConnectionState, + PerConnectionState, + RoomSyncConfig, + SlidingSyncConfig, + SlidingSyncResult, + StateValues, +) +from synapse.types.state import StateFilter +from synapse.util.async_helpers import concurrently_execute +from synapse.visibility import filter_events_for_client + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +sync_processing_time = Histogram( + "synapse_sliding_sync_processing_time", + "Time taken to generate a sliding sync response, ignoring wait times.", + ["initial"], +) + +# Limit the number of state_keys we should remember sending down the connection for each +# (room_id, user_id). We don't want to store and pull out too much data in the database. +# +# 100 is an arbitrary but small-ish number. The idea is that we probably won't send down +# too many redundant member state events (that the client already knows about) for a +# given ongoing conversation if we keep 100 around. Most rooms don't have 100 members +# anyway and it takes a while to cycle through 100 members. 
+MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER = 100 + + +class SlidingSyncHandler: + def __init__(self, hs: "HomeServer"): + self.clock = hs.get_clock() + self.store = hs.get_datastores().main + self.storage_controllers = hs.get_storage_controllers() + self.auth_blocking = hs.get_auth_blocking() + self.notifier = hs.get_notifier() + self.event_sources = hs.get_event_sources() + self.relations_handler = hs.get_relations_handler() + self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync + self.is_mine_id = hs.is_mine_id + + self.connection_store = SlidingSyncConnectionStore(self.store) + self.extensions = SlidingSyncExtensionHandler(hs) + self.room_lists = SlidingSyncRoomLists(hs) + + async def wait_for_sync_for_user( + self, + requester: Requester, + sync_config: SlidingSyncConfig, + from_token: Optional[SlidingSyncStreamToken] = None, + timeout_ms: int = 0, + ) -> SlidingSyncResult: + """ + Get the sync for a client if we have new data for it now. Otherwise + wait for new data to arrive on the server. If the timeout expires, then + return an empty sync result. + + Args: + requester: The user making the request + sync_config: Sync configuration + from_token: The point in the stream to sync from. Token of the end of the + previous batch. May be `None` if this is the initial sync request. + timeout_ms: The time in milliseconds to wait for new data to arrive. If 0, + we will immediately but there might not be any new data so we just return an + empty response. 
+ """ + # If the user is not part of the mau group, then check that limits have + # not been exceeded (if not part of the group by this point, almost certain + # auth_blocking will occur) + await self.auth_blocking.check_auth_blocking(requester=requester) + + # If we're working with a user-provided token, we need to make sure to wait for + # this worker to catch up with the token so we don't skip past any incoming + # events or future events if the user is nefariously, manually modifying the + # token. + if from_token is not None: + # We need to make sure this worker has caught up with the token. If + # this returns false, it means we timed out waiting, and we should + # just return an empty response. + before_wait_ts = self.clock.time_msec() + if not await self.notifier.wait_for_stream_token(from_token.stream_token): + logger.warning( + "Timed out waiting for worker to catch up. Returning empty response" + ) + return SlidingSyncResult.empty(from_token) + + # If we've spent significant time waiting to catch up, take it off + # the timeout. + after_wait_ts = self.clock.time_msec() + if after_wait_ts - before_wait_ts > 1_000: + timeout_ms -= after_wait_ts - before_wait_ts + timeout_ms = max(timeout_ms, 0) + + # We're going to respond immediately if the timeout is 0 or if this is an + # initial sync (without a `from_token`) so we can avoid calling + # `notifier.wait_for_events()`. + if timeout_ms == 0 or from_token is None: + now_token = self.event_sources.get_current_token() + result = await self.current_sync_for_user( + sync_config, + from_token=from_token, + to_token=now_token, + ) + else: + # Otherwise, we wait for something to happen and report it to the user. 
+ async def current_sync_callback( + before_token: StreamToken, after_token: StreamToken + ) -> SlidingSyncResult: + return await self.current_sync_for_user( + sync_config, + from_token=from_token, + to_token=after_token, + ) + + result = await self.notifier.wait_for_events( + sync_config.user.to_string(), + timeout_ms, + current_sync_callback, + from_token=from_token.stream_token, + ) + + return result + + @trace + async def current_sync_for_user( + self, + sync_config: SlidingSyncConfig, + to_token: StreamToken, + from_token: Optional[SlidingSyncStreamToken] = None, + ) -> SlidingSyncResult: + """ + Generates the response body of a Sliding Sync result, represented as a + `SlidingSyncResult`. + + We fetch data according to the token range (> `from_token` and <= `to_token`). + + Args: + sync_config: Sync configuration + to_token: The point in the stream to sync up to. + from_token: The point in the stream to sync from. Token of the end of the + previous batch. May be `None` if this is the initial sync request. + """ + start_time_s = self.clock.time() + + user_id = sync_config.user.to_string() + app_service = self.store.get_app_service_by_user_id(user_id) + if app_service: + # We no longer support AS users using /sync directly. + # See https://github.com/matrix-org/matrix-doc/issues/1144 + raise NotImplementedError() + + # Get the per-connection state (if any). + # + # Raises an exception if there is a `connection_position` that we don't + # recognize. If we don't do this and the client asks for the full range + # of rooms, we end up sending down all rooms and their state from + # scratch (which can be very slow). By expiring the connection we allow + # the client a chance to do an initial request with a smaller range of + # rooms to get them some results sooner but will end up taking the same + # amount of time (more with round-trips and re-processing) in the end to + # get everything again. 
+ previous_connection_state = ( + await self.connection_store.get_and_clear_connection_positions( + sync_config, from_token + ) + ) + + # Get all of the room IDs that the user should be able to see in the sync + # response + has_lists = sync_config.lists is not None and len(sync_config.lists) > 0 + has_room_subscriptions = ( + sync_config.room_subscriptions is not None + and len(sync_config.room_subscriptions) > 0 + ) + + interested_rooms = await self.room_lists.compute_interested_rooms( + sync_config=sync_config, + previous_connection_state=previous_connection_state, + from_token=from_token.stream_token if from_token else None, + to_token=to_token, + ) + + lists = interested_rooms.lists + relevant_room_map = interested_rooms.relevant_room_map + all_rooms = interested_rooms.all_rooms + room_membership_for_user_map = interested_rooms.room_membership_for_user_map + relevant_rooms_to_send_map = interested_rooms.relevant_rooms_to_send_map + + # Fetch room data + rooms: Dict[str, SlidingSyncResult.RoomResult] = {} + + new_connection_state = previous_connection_state.get_mutable() + + @trace + @tag_args + async def handle_room(room_id: str) -> None: + room_sync_result = await self.get_room_sync_data( + sync_config=sync_config, + previous_connection_state=previous_connection_state, + new_connection_state=new_connection_state, + room_id=room_id, + room_sync_config=relevant_rooms_to_send_map[room_id], + room_membership_for_user_at_to_token=room_membership_for_user_map[ + room_id + ], + from_token=from_token, + to_token=to_token, + newly_joined=room_id in interested_rooms.newly_joined_rooms, + newly_left=room_id in interested_rooms.newly_left_rooms, + is_dm=room_id in interested_rooms.dm_room_ids, + ) + + # Filter out empty room results during incremental sync + if room_sync_result or not from_token: + rooms[room_id] = room_sync_result + + if relevant_rooms_to_send_map: + with start_active_span("sliding_sync.generate_room_entries"): + await concurrently_execute(handle_room, 
relevant_rooms_to_send_map, 20) + + extensions = await self.extensions.get_extensions_response( + sync_config=sync_config, + actual_lists=lists, + previous_connection_state=previous_connection_state, + new_connection_state=new_connection_state, + # We're purposely using `relevant_room_map` instead of + # `relevant_rooms_to_send_map` here. This needs to be all room_ids we could + # send regardless of whether they have an event update or not. The + # extensions care about more than just normal events in the rooms (like + # account data, read receipts, typing indicators, to-device messages, etc). + actual_room_ids=set(relevant_room_map.keys()), + actual_room_response_map=rooms, + from_token=from_token, + to_token=to_token, + ) + + if has_lists or has_room_subscriptions: + # We now calculate if any rooms outside the range have had updates, + # which we are not sending down. + # + # We *must* record rooms that have had updates, but it is also fine + # to record rooms as having updates even if there might not actually + # be anything new for the user (e.g. due to event filters, events + # having happened after the user left, etc). + if from_token: + # The set of rooms that the client (may) care about, but aren't + # in any list range (or subscribed to). + missing_rooms = all_rooms - relevant_room_map.keys() + + # We now just go and try fetching any events in the above rooms + # to see if anything has happened since the `from_token`. + # + # TODO: Replace this with something faster. When we land the + # sliding sync tables that record the most recent event + # positions we can use that. 
+ unsent_room_ids: StrCollection + if await self.store.have_finished_sliding_sync_background_jobs(): + unsent_room_ids = await ( + self.store.get_rooms_that_have_updates_since_sliding_sync_table( + room_ids=missing_rooms, + from_key=from_token.stream_token.room_key, + ) + ) + else: + missing_event_map_by_room = ( + await self.store.get_room_events_stream_for_rooms( + room_ids=missing_rooms, + from_key=to_token.room_key, + to_key=from_token.stream_token.room_key, + limit=1, + ) + ) + unsent_room_ids = list(missing_event_map_by_room) + + new_connection_state.rooms.record_unsent_rooms( + unsent_room_ids, from_token.stream_token.room_key + ) + + new_connection_state.rooms.record_sent_rooms( + relevant_rooms_to_send_map.keys() + ) + + connection_position = await self.connection_store.record_new_state( + sync_config=sync_config, + from_token=from_token, + new_connection_state=new_connection_state, + ) + elif from_token: + connection_position = from_token.connection_position + else: + # Initial sync without a `from_token` starts at `0` + connection_position = 0 + + sliding_sync_result = SlidingSyncResult( + next_pos=SlidingSyncStreamToken(to_token, connection_position), + lists=lists, + rooms=rooms, + extensions=extensions, + ) + + # Make it easy to find traces for syncs that aren't empty + set_tag(SynapseTags.RESULT_PREFIX + "result", bool(sliding_sync_result)) + set_tag(SynapseTags.FUNC_ARG_PREFIX + "sync_config.user", user_id) + + end_time_s = self.clock.time() + sync_processing_time.labels(from_token is not None).observe( + end_time_s - start_time_s + ) + + return sliding_sync_result + + @trace + async def get_current_state_ids_at( + self, + room_id: str, + room_membership_for_user_at_to_token: RoomsForUserType, + state_filter: StateFilter, + to_token: StreamToken, + ) -> StateMap[str]: + """ + Get current state IDs for the user in the room according to their membership. 
This + will be the current state at the time of their LEAVE/BAN, otherwise will be the + current state <= to_token. + + Args: + room_id: The room ID to fetch data for + room_membership_for_user_at_token: Membership information for the user + in the room at the time of `to_token`. + to_token: The point in the stream to sync up to. + """ + state_ids: StateMap[str] + # People shouldn't see past their leave/ban event + if room_membership_for_user_at_to_token.membership in ( + Membership.LEAVE, + Membership.BAN, + ): + # TODO: `get_state_ids_at(...)` doesn't take into account the "current + # state". Maybe we need to use + # `get_forward_extremities_for_room_at_stream_ordering(...)` to "Fetch the + # current state at the time." + state_ids = await self.storage_controllers.state.get_state_ids_at( + room_id, + stream_position=to_token.copy_and_replace( + StreamKeyType.ROOM, + room_membership_for_user_at_to_token.event_pos.to_room_stream_token(), + ), + state_filter=state_filter, + # Partially-stated rooms should have all state events except for + # remote membership events. Since we've already excluded + # partially-stated rooms unless `required_state` only has + # `["m.room.member", "$LAZY"]` for membership, we should be able to + # retrieve everything requested. When we're lazy-loading, if there + # are some remote senders in the timeline, we should also have their + # membership event because we had to auth that timeline event. Plus + # we don't want to block the whole sync waiting for this one room. + await_full_state=False, + ) + # Otherwise, we can get the latest current state in the room + else: + state_ids = await self.storage_controllers.state.get_current_state_ids( + room_id, + state_filter, + # Partially-stated rooms should have all state events except for + # remote membership events. 
Since we've already excluded + # partially-stated rooms unless `required_state` only has + # `["m.room.member", "$LAZY"]` for membership, we should be able to + # retrieve everything requested. When we're lazy-loading, if there + # are some remote senders in the timeline, we should also have their + # membership event because we had to auth that timeline event. Plus + # we don't want to block the whole sync waiting for this one room. + await_full_state=False, + ) + # TODO: Query `current_state_delta_stream` and reverse/rewind back to the `to_token` + + return state_ids + + @trace + async def get_current_state_at( + self, + room_id: str, + room_membership_for_user_at_to_token: RoomsForUserType, + state_filter: StateFilter, + to_token: StreamToken, + ) -> StateMap[EventBase]: + """ + Get current state for the user in the room according to their membership. This + will be the current state at the time of their LEAVE/BAN, otherwise will be the + current state <= to_token. + + Args: + room_id: The room ID to fetch data for + room_membership_for_user_at_token: Membership information for the user + in the room at the time of `to_token`. + to_token: The point in the stream to sync up to. + """ + state_ids = await self.get_current_state_ids_at( + room_id=room_id, + room_membership_for_user_at_to_token=room_membership_for_user_at_to_token, + state_filter=state_filter, + to_token=to_token, + ) + + events = await self.store.get_events_as_list(list(state_ids.values())) + + state_map = {} + for event in events: + state_map[(event.type, event.state_key)] = event + + return state_map + + @trace + async def get_current_state_deltas_for_room( + self, + room_id: str, + room_membership_for_user_at_to_token: RoomsForUserType, + from_token: RoomStreamToken, + to_token: RoomStreamToken, + ) -> List[StateDelta]: + """ + Get the state deltas between two tokens taking into account the user's + membership. 
@trace
async def get_current_state_deltas_for_room(
    self,
    room_id: str,
    room_membership_for_user_at_to_token: RoomsForUserType,
    from_token: RoomStreamToken,
    to_token: RoomStreamToken,
) -> List[StateDelta]:
    """
    Fetch the state deltas in the range (> `from_token` and <= `to_token`),
    clamped according to the user's membership in the room.

    - LEAVE/BAN: the upper bound becomes the position of the user's leave/ban
      event (inclusive), since people shouldn't see past it.
    - JOIN: the upper bound is `to_token`.
    - INVITE/KNOCK: always an empty list, because only the stripped state carried
      by the invite/knock event itself can be relied on.

    Args:
        room_id: The room ID to fetch deltas for
        room_membership_for_user_at_to_token: Membership information for the user
            in the room at the time of `to_token`.
        from_token: Exclusive lower bound of the range.
        to_token: Inclusive upper bound of the range (before membership clamping).

    Raises:
        AssertionError: If the membership is not one of JOIN, INVITE, KNOCK,
            LEAVE or BAN.
    """
    membership = room_membership_for_user_at_to_token.membership

    # Only these membership values are understood here; anything else means the
    # branches below need to be extended.
    assert membership in (
        Membership.JOIN,
        Membership.INVITE,
        Membership.KNOCK,
        Membership.LEAVE,
        Membership.BAN,
    )

    # Invite/knock rooms only ever expose stripped state, so there are never any
    # state deltas to send down.
    if membership in (Membership.INVITE, Membership.KNOCK):
        return []

    if membership in (Membership.LEAVE, Membership.BAN):
        # People shouldn't see past their leave/ban event
        upper_bound = (
            room_membership_for_user_at_to_token.event_pos.to_room_stream_token()
        )
    elif membership == Membership.JOIN:
        # Still participating in the room, so deltas up to `to_token` are visible
        upper_bound = to_token
    else:
        # We don't know how to handle this type of membership yet
        #
        # FIXME: `assert_never(membership)` would be preferable here, but the
        # exhaustive matching doesn't recognize the `Never`.
        raise AssertionError(
            f"Unexpected membership {membership} that we don't know how to handle yet"
        )

    return await self.store.get_current_state_deltas_for_room(
        room_id=room_id,
        from_token=from_token,
        to_token=upper_bound,
    )
+ # assert_never(membership) + raise AssertionError( + f"Unexpected membership {membership} that we don't know how to handle yet" + ) + + return await self.store.get_current_state_deltas_for_room( + room_id=room_id, + from_token=from_token, + to_token=to_bound, + ) + + @trace + async def get_room_sync_data( + self, + sync_config: SlidingSyncConfig, + previous_connection_state: "PerConnectionState", + new_connection_state: "MutablePerConnectionState", + room_id: str, + room_sync_config: RoomSyncConfig, + room_membership_for_user_at_to_token: RoomsForUserType, + from_token: Optional[SlidingSyncStreamToken], + to_token: StreamToken, + newly_joined: bool, + newly_left: bool, + is_dm: bool, + ) -> SlidingSyncResult.RoomResult: + """ + Fetch room data for the sync response. + + We fetch data according to the token range (> `from_token` and <= `to_token`). + + Args: + user: User to fetch data for + room_id: The room ID to fetch data for + room_sync_config: Config for what data we should fetch for a room in the + sync response. + room_membership_for_user_at_to_token: Membership information for the user + in the room at the time of `to_token`. + from_token: The point in the stream to sync from. + to_token: The point in the stream to sync up to. + newly_joined: If the user has newly joined the room + newly_left: If the user has newly left the room + is_dm: Whether the room is a DM room + """ + user = sync_config.user + + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "membership", + room_membership_for_user_at_to_token.membership, + ) + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "timeline_limit", + room_sync_config.timeline_limit, + ) + + # Handle state resets. For example, if we see + # `room_membership_for_user_at_to_token.event_id=None and + # room_membership_for_user_at_to_token.membership is not None`, we should + # indicate to the client that a state reset happened. Perhaps we should indicate + # this by setting `initial: True` and empty `required_state: []`. 
+ state_reset_out_of_room = False + if ( + room_membership_for_user_at_to_token.event_id is None + and room_membership_for_user_at_to_token.membership is not None + ): + # We only expect the `event_id` to be `None` if you've been state reset out + # of the room (meaning you're no longer in the room). We could put this as + # part of the if-statement above but we want to handle every case where + # `event_id` is `None`. + assert room_membership_for_user_at_to_token.membership is Membership.LEAVE + + state_reset_out_of_room = True + + prev_room_sync_config = previous_connection_state.room_configs.get(room_id) + + # Determine whether we should limit the timeline to the token range. + # + # We should return historical messages (before token range) in the + # following cases because we want clients to be able to show a basic + # screen of information: + # + # - Initial sync (because no `from_token` to limit us anyway) + # - When users `newly_joined` + # - For an incremental sync where we haven't sent it down this + # connection before + # + # Relevant spec issue: + # https://github.com/matrix-org/matrix-spec/issues/1917 + # + # XXX: Odd behavior - We also check if the `timeline_limit` has increased, if so + # we ignore the from bound for the timeline to send down a larger chunk of + # history and set `unstable_expanded_timeline` to true. This is only being added + # to match the behavior of the Sliding Sync proxy as we expect the ElementX + # client to feel a certain way and be able to trickle in a full page of timeline + # messages to fill up the screen. This is a bit different to the behavior of the + # Sliding Sync proxy (which sets initial=true, but then doesn't send down the + # full state again), but existing apps, e.g. ElementX, just need `limited` set. + # We don't explicitly set `limited` but this will be the case for any room that + # has more history than we're trying to pull out. 
Using + # `unstable_expanded_timeline` allows us to avoid contaminating what `initial` + # or `limited` mean for clients that interpret them correctly. In future this + # behavior is almost certainly going to change. + # + from_bound = None + initial = True + ignore_timeline_bound = False + if from_token and not newly_joined and not state_reset_out_of_room: + room_status = previous_connection_state.rooms.have_sent_room(room_id) + if room_status.status == HaveSentRoomFlag.LIVE: + from_bound = from_token.stream_token.room_key + initial = False + elif room_status.status == HaveSentRoomFlag.PREVIOUSLY: + assert room_status.last_token is not None + from_bound = room_status.last_token + initial = False + elif room_status.status == HaveSentRoomFlag.NEVER: + from_bound = None + initial = True + else: + assert_never(room_status.status) + + log_kv({"sliding_sync.room_status": room_status}) + + if prev_room_sync_config is not None: + # Check if the timeline limit has increased, if so ignore the + # timeline bound and record the change (see "XXX: Odd behavior" + # above). + if ( + prev_room_sync_config.timeline_limit + < room_sync_config.timeline_limit + ): + ignore_timeline_bound = True + + log_kv( + { + "sliding_sync.from_bound": from_bound, + "sliding_sync.initial": initial, + "sliding_sync.ignore_timeline_bound": ignore_timeline_bound, + } + ) + + # Assemble the list of timeline events + # + # FIXME: It would be nice to make the `rooms` response more uniform regardless of + # membership. Currently, we have to make all of these optional because + # `invite`/`knock` rooms only have `stripped_state`. 
See + # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932 + timeline_events: List[EventBase] = [] + bundled_aggregations: Optional[Dict[str, BundledAggregations]] = None + limited: Optional[bool] = None + prev_batch_token: Optional[StreamToken] = None + num_live: Optional[int] = None + if ( + room_sync_config.timeline_limit > 0 + # No timeline for invite/knock rooms (just `stripped_state`) + and room_membership_for_user_at_to_token.membership + not in (Membership.INVITE, Membership.KNOCK) + ): + limited = False + # We want to start off using the `to_token` (vs `from_token`) because we look + # backwards from the `to_token` up to the `timeline_limit` and we might not + # reach the `from_token` before we hit the limit. We will update the room stream + # position once we've fetched the events to point to the earliest event fetched. + prev_batch_token = to_token + + # We're going to paginate backwards from the `to_token` + to_bound = to_token.room_key + # People shouldn't see past their leave/ban event + if room_membership_for_user_at_to_token.membership in ( + Membership.LEAVE, + Membership.BAN, + ): + to_bound = room_membership_for_user_at_to_token.event_pos.to_room_stream_token() + + timeline_from_bound = from_bound + if ignore_timeline_bound: + timeline_from_bound = None + + # For initial `/sync` (and other historical scenarios mentioned above), we + # want to view a historical section of the timeline; to fetch events by + # `topological_ordering` (best representation of the room DAG as others were + # seeing it at the time). This also aligns with the order that `/messages` + # returns events in. + # + # For incremental `/sync`, we want to get all updates for rooms since + # the last `/sync` (regardless if those updates arrived late or happened + # a while ago in the past); to fetch events by `stream_ordering` (in the + # order they were received by the server). 
+ # + # Relevant spec issue: https://github.com/matrix-org/matrix-spec/issues/1917 + # + # FIXME: Using workaround for mypy, + # https://github.com/python/mypy/issues/10740#issuecomment-1997047277 and + # https://github.com/python/mypy/issues/17479 + paginate_room_events_by_topological_ordering: PaginateFunction = ( + self.store.paginate_room_events_by_topological_ordering + ) + paginate_room_events_by_stream_ordering: PaginateFunction = ( + self.store.paginate_room_events_by_stream_ordering + ) + pagination_method: PaginateFunction = ( + # Use `topographical_ordering` for historical events + paginate_room_events_by_topological_ordering + if timeline_from_bound is None + # Use `stream_ordering` for updates + else paginate_room_events_by_stream_ordering + ) + timeline_events, new_room_key, limited = await pagination_method( + room_id=room_id, + # The bounds are reversed so we can paginate backwards + # (from newer to older events) starting at to_bound. + # This ensures we fill the `limit` with the newest events first, + from_key=to_bound, + to_key=timeline_from_bound, + direction=Direction.BACKWARDS, + limit=room_sync_config.timeline_limit, + ) + + # We want to return the events in ascending order (the last event is the + # most recent). + timeline_events.reverse() + + # Make sure we don't expose any events that the client shouldn't see + timeline_events = await filter_events_for_client( + self.storage_controllers, + user.to_string(), + timeline_events, + is_peeking=room_membership_for_user_at_to_token.membership + != Membership.JOIN, + filter_send_to_client=True, + ) + # TODO: Filter out `EventTypes.CallInvite` in public rooms, + # see https://github.com/element-hq/synapse/issues/17359 + + # TODO: Handle timeline gaps (`get_timeline_gaps()`) + + # Determine how many "live" events we have (events within the given token range). + # + # This is mostly useful to determine whether a given @mention event should + # make a noise or not. 
Clients cannot rely solely on the absence of + # `initial: true` to determine live events because if a room not in the + # sliding window bumps into the window because of an @mention it will have + # `initial: true` yet contain a single live event (with potentially other + # old events in the timeline) + num_live = 0 + if from_token is not None: + for timeline_event in reversed(timeline_events): + # This fields should be present for all persisted events + assert timeline_event.internal_metadata.stream_ordering is not None + assert timeline_event.internal_metadata.instance_name is not None + + persisted_position = PersistedEventPosition( + instance_name=timeline_event.internal_metadata.instance_name, + stream=timeline_event.internal_metadata.stream_ordering, + ) + if persisted_position.persisted_after( + from_token.stream_token.room_key + ): + num_live += 1 + else: + # Since we're iterating over the timeline events in + # reverse-chronological order, we can break once we hit an event + # that's not live. In the future, we could potentially optimize + # this more with a binary search (bisect). + break + + # If the timeline is `limited=True`, the client does not have all events + # necessary to calculate aggregations themselves. + if limited: + bundled_aggregations = ( + await self.relations_handler.get_bundled_aggregations( + timeline_events, user.to_string() + ) + ) + + # Update the `prev_batch_token` to point to the position that allows us to + # keep paginating backwards from the oldest event we return in the timeline. + prev_batch_token = prev_batch_token.copy_and_replace( + StreamKeyType.ROOM, new_room_key + ) + + # Figure out any stripped state events for invite/knocks. This allows the + # potential joiner to identify the room. + stripped_state: List[JsonDict] = [] + if room_membership_for_user_at_to_token.membership in ( + Membership.INVITE, + Membership.KNOCK, + ): + # This should never happen. 
If someone is invited/knocked on room, then + # there should be an event for it. + assert room_membership_for_user_at_to_token.event_id is not None + + invite_or_knock_event = await self.store.get_event( + room_membership_for_user_at_to_token.event_id + ) + + stripped_state = [] + if invite_or_knock_event.membership == Membership.INVITE: + invite_state = invite_or_knock_event.unsigned.get( + "invite_room_state", [] + ) + if not isinstance(invite_state, list): + invite_state = [] + + stripped_state.extend(invite_state) + elif invite_or_knock_event.membership == Membership.KNOCK: + knock_state = invite_or_knock_event.unsigned.get("knock_room_state", []) + if not isinstance(knock_state, list): + knock_state = [] + + stripped_state.extend(knock_state) + + stripped_state.append(strip_event(invite_or_knock_event)) + + # Get the changes to current state in the token range from the + # `current_state_delta_stream` table. + # + # For incremental syncs, we can do this first to determine if something relevant + # has changed and strategically avoid fetching other costly things. + room_state_delta_id_map: MutableStateMap[str] = {} + name_event_id: Optional[str] = None + membership_changed = False + name_changed = False + avatar_changed = False + if initial: + # Check whether the room has a name set + name_state_ids = await self.get_current_state_ids_at( + room_id=room_id, + room_membership_for_user_at_to_token=room_membership_for_user_at_to_token, + state_filter=StateFilter.from_types([(EventTypes.Name, "")]), + to_token=to_token, + ) + name_event_id = name_state_ids.get((EventTypes.Name, "")) + else: + assert from_bound is not None + + # TODO: Limit the number of state events we're about to send down + # the room, if its too many we should change this to an + # `initial=True`? + + # For the case of rejecting remote invites, the leave event won't be + # returned by `get_current_state_deltas_for_room`. 
This is due to the current + # state only being filled out for rooms the server is in, and so doesn't pick + # up out-of-band leaves (including locally rejected invites) as these events + # are outliers and not added to the `current_state_delta_stream`. + # + # We rely on being explicitly told that the room has been `newly_left` to + # ensure we extract the out-of-band leave. + if newly_left and room_membership_for_user_at_to_token.event_id is not None: + membership_changed = True + leave_event = await self.store.get_event( + room_membership_for_user_at_to_token.event_id + ) + state_key = leave_event.get_state_key() + if state_key is not None: + room_state_delta_id_map[(leave_event.type, state_key)] = ( + room_membership_for_user_at_to_token.event_id + ) + + deltas = await self.get_current_state_deltas_for_room( + room_id=room_id, + room_membership_for_user_at_to_token=room_membership_for_user_at_to_token, + from_token=from_bound, + to_token=to_token.room_key, + ) + for delta in deltas: + # TODO: Handle state resets where event_id is None + if delta.event_id is not None: + room_state_delta_id_map[(delta.event_type, delta.state_key)] = ( + delta.event_id + ) + + if delta.event_type == EventTypes.Member: + membership_changed = True + elif delta.event_type == EventTypes.Name and delta.state_key == "": + name_changed = True + elif ( + delta.event_type == EventTypes.RoomAvatar and delta.state_key == "" + ): + avatar_changed = True + + # We only need the room summary for calculating heroes, however if we do + # fetch it then we can use it to calculate `joined_count` and + # `invited_count`. + room_membership_summary: Optional[Mapping[str, MemberSummary]] = None + + # `heroes` are required if the room name is not set. 
+ # + # Note: When you're the first one on your server to be invited to a new room + # over federation, we only have access to some stripped state in + # `event.unsigned.invite_room_state` which currently doesn't include `heroes`, + # see https://github.com/matrix-org/matrix-spec/issues/380. This means that + # clients won't be able to calculate the room name when necessary and just a + # pitfall we have to deal with until that spec issue is resolved. + hero_user_ids: List[str] = [] + # TODO: Should we also check for `EventTypes.CanonicalAlias` + # (`m.room.canonical_alias`) as a fallback for the room name? see + # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1671260153 + # + # We need to fetch the `heroes` if the room name is not set. But we only need to + # get them on initial syncs (or the first time we send down the room) or if the + # membership has changed which may change the heroes. + if name_event_id is None and (initial or (not initial and membership_changed)): + # We need the room summary to extract the heroes from + if room_membership_for_user_at_to_token.membership != Membership.JOIN: + # TODO: Figure out how to get the membership summary for left/banned rooms + # For invite/knock rooms we don't include the information. + room_membership_summary = {} + else: + room_membership_summary = await self.store.get_room_summary(room_id) + # TODO: Reverse/rewind back to the `to_token` + + hero_user_ids = extract_heroes_from_room_summary( + room_membership_summary, me=user.to_string() + ) + + # Fetch the membership counts for rooms we're joined to. + # + # Similarly to other metadata, we only need to calculate the member + # counts if this is an initial sync or the memberships have changed. 
+ joined_count: Optional[int] = None + invited_count: Optional[int] = None + if ( + initial or membership_changed + ) and room_membership_for_user_at_to_token.membership == Membership.JOIN: + # If we have the room summary (because we calculated heroes above) + # then we can simply pull the counts from there. + if room_membership_summary is not None: + empty_membership_summary = MemberSummary([], 0) + + joined_count = room_membership_summary.get( + Membership.JOIN, empty_membership_summary + ).count + + invited_count = room_membership_summary.get( + Membership.INVITE, empty_membership_summary + ).count + else: + member_counts = await self.store.get_member_counts(room_id) + joined_count = member_counts.get(Membership.JOIN, 0) + invited_count = member_counts.get(Membership.INVITE, 0) + + # Fetch the `required_state` for the room + # + # No `required_state` for invite/knock rooms (just `stripped_state`) + # + # FIXME: It would be nice to make the `rooms` response more uniform regardless + # of membership. Currently, we have to make this optional because + # `invite`/`knock` rooms only have `stripped_state`. See + # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932 + # + # Calculate the `StateFilter` based on the `required_state` for the room + required_state_filter = StateFilter.none() + # The requested `required_state_map` with the lazy membership expanded and + # `$ME` replaced with the user's ID. This allows us to see what membership we've + # sent down to the client in the next request. + # + # Make a copy so we can modify it. Still need to be careful to make a copy of + # the state key sets if we want to add/remove from them. We could make a deep + # copy but this saves us some work. 
+ expanded_required_state_map = dict(room_sync_config.required_state_map) + if room_membership_for_user_at_to_token.membership not in ( + Membership.INVITE, + Membership.KNOCK, + ): + # If we have a double wildcard ("*", "*") in the `required_state`, we need + # to fetch all state for the room + # + # Note: MSC3575 describes different behavior to how we're handling things + # here but since it's not wrong to return more state than requested + # (`required_state` is just the minimum requested), it doesn't matter if we + # include more than client wanted. This complexity is also under scrutiny, + # see + # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1185109050 + # + # > One unique exception is when you request all state events via ["*", "*"]. When used, + # > all state events are returned by default, and additional entries FILTER OUT the returned set + # > of state events. These additional entries cannot use '*' themselves. + # > For example, ["*", "*"], ["m.room.member", "@alice:example.com"] will _exclude_ every m.room.member + # > event _except_ for @alice:example.com, and include every other state event. + # > In addition, ["*", "*"], ["m.space.child", "*"] is an error, the m.space.child filter is not + # > required as it would have been returned anyway. + # > + # > -- MSC3575 (https://github.com/matrix-org/matrix-spec-proposals/pull/3575) + if StateValues.WILDCARD in room_sync_config.required_state_map.get( + StateValues.WILDCARD, set() + ): + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "required_state_wildcard", + True, + ) + required_state_filter = StateFilter.all() + # TODO: `StateFilter` currently doesn't support wildcard event types. We're + # currently working around this by returning all state to the client but it + # would be nice to fetch less from the database and return just what the + # client wanted. 
+ elif ( + room_sync_config.required_state_map.get(StateValues.WILDCARD) + is not None + ): + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "required_state_wildcard_event_type", + True, + ) + required_state_filter = StateFilter.all() + else: + required_state_types: List[Tuple[str, Optional[str]]] = [] + num_wild_state_keys = 0 + lazy_load_room_members = False + num_others = 0 + for ( + state_type, + state_key_set, + ) in room_sync_config.required_state_map.items(): + for state_key in state_key_set: + if state_key == StateValues.WILDCARD: + num_wild_state_keys += 1 + # `None` is a wildcard in the `StateFilter` + required_state_types.append((state_type, None)) + # We need to fetch all relevant people when we're lazy-loading membership + elif ( + state_type == EventTypes.Member + and state_key == StateValues.LAZY + ): + lazy_load_room_members = True + + # Everyone in the timeline is relevant + timeline_membership: Set[str] = set() + if timeline_events is not None: + for timeline_event in timeline_events: + # Anyone who sent a message is relevant + timeline_membership.add(timeline_event.sender) + + # We also care about invite, ban, kick, targets, + # etc. 
+ if timeline_event.type == EventTypes.Member: + timeline_membership.add( + timeline_event.state_key + ) + + # Update the required state filter so we pick up the new + # membership + for user_id in timeline_membership: + required_state_types.append( + (EventTypes.Member, user_id) + ) + + # Add an explicit entry for each user in the timeline + # + # Make a new set or copy of the state key set so we can + # modify it without affecting the original + # `required_state_map` + expanded_required_state_map[EventTypes.Member] = ( + expanded_required_state_map.get( + EventTypes.Member, set() + ) + | timeline_membership + ) + elif state_key == StateValues.ME: + num_others += 1 + required_state_types.append((state_type, user.to_string())) + # Replace `$ME` with the user's ID so we can deduplicate + # when someone requests the same state with `$ME` or with + # their user ID. + # + # Make a new set or copy of the state key set so we can + # modify it without affecting the original + # `required_state_map` + expanded_required_state_map[EventTypes.Member] = ( + expanded_required_state_map.get( + EventTypes.Member, set() + ) + | {user.to_string()} + ) + else: + num_others += 1 + required_state_types.append((state_type, state_key)) + + set_tag( + SynapseTags.FUNC_ARG_PREFIX + + "required_state_wildcard_state_key_count", + num_wild_state_keys, + ) + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "required_state_lazy", + lazy_load_room_members, + ) + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "required_state_other_count", + num_others, + ) + + required_state_filter = StateFilter.from_types(required_state_types) + + # We need this base set of info for the response so let's just fetch it along + # with the `required_state` for the room + hero_room_state = [ + (EventTypes.Member, hero_user_id) for hero_user_id in hero_user_ids + ] + meta_room_state = list(hero_room_state) + if initial or name_changed: + meta_room_state.append((EventTypes.Name, "")) + if initial or avatar_changed: + 
meta_room_state.append((EventTypes.RoomAvatar, "")) + + state_filter = StateFilter.all() + if required_state_filter != StateFilter.all(): + state_filter = StateFilter( + types=StateFilter.from_types( + chain(meta_room_state, required_state_filter.to_types()) + ).types, + include_others=required_state_filter.include_others, + ) + + # The required state map to store in the room sync config, if it has + # changed. + changed_required_state_map: Optional[Mapping[str, AbstractSet[str]]] = None + + # We can return all of the state that was requested if this was the first + # time we've sent the room down this connection. + room_state: StateMap[EventBase] = {} + if initial: + room_state = await self.get_current_state_at( + room_id=room_id, + room_membership_for_user_at_to_token=room_membership_for_user_at_to_token, + state_filter=state_filter, + to_token=to_token, + ) + else: + assert from_bound is not None + + if prev_room_sync_config is not None: + # Check if there are any changes to the required state config + # that we need to handle. + changed_required_state_map, added_state_filter = ( + _required_state_changes( + user.to_string(), + prev_required_state_map=prev_room_sync_config.required_state_map, + request_required_state_map=expanded_required_state_map, + state_deltas=room_state_delta_id_map, + ) + ) + + if added_state_filter: + # Some state entries got added, so we pull out the current + # state for them. If we don't do this we'd only send down new deltas. 
+ state_ids = await self.get_current_state_ids_at( + room_id=room_id, + room_membership_for_user_at_to_token=room_membership_for_user_at_to_token, + state_filter=added_state_filter, + to_token=to_token, + ) + room_state_delta_id_map.update(state_ids) + + events = await self.store.get_events( + state_filter.filter_state(room_state_delta_id_map).values() + ) + room_state = {(s.type, s.state_key): s for s in events.values()} + + # If the membership changed and we have to get heroes, get the remaining + # heroes from the state + if hero_user_ids: + hero_membership_state = await self.get_current_state_at( + room_id=room_id, + room_membership_for_user_at_to_token=room_membership_for_user_at_to_token, + state_filter=StateFilter.from_types(hero_room_state), + to_token=to_token, + ) + room_state.update(hero_membership_state) + + required_room_state: StateMap[EventBase] = {} + if required_state_filter != StateFilter.none(): + required_room_state = required_state_filter.filter_state(room_state) + + # Find the room name and avatar from the state + room_name: Optional[str] = None + # TODO: Should we also check for `EventTypes.CanonicalAlias` + # (`m.room.canonical_alias`) as a fallback for the room name? 
see + # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1671260153 + name_event = room_state.get((EventTypes.Name, "")) + if name_event is not None: + room_name = name_event.content.get("name") + + room_avatar: Optional[str] = None + avatar_event = room_state.get((EventTypes.RoomAvatar, "")) + if avatar_event is not None: + room_avatar = avatar_event.content.get("url") + + # Assemble heroes: extract the info from the state we just fetched + heroes: List[SlidingSyncResult.RoomResult.StrippedHero] = [] + for hero_user_id in hero_user_ids: + member_event = room_state.get((EventTypes.Member, hero_user_id)) + if member_event is not None: + heroes.append( + SlidingSyncResult.RoomResult.StrippedHero( + user_id=hero_user_id, + display_name=member_event.content.get("displayname"), + avatar_url=member_event.content.get("avatar_url"), + ) + ) + + # Figure out the last bump event in the room. If the bump stamp hasn't + # changed we omit it from the response. + bump_stamp = None + + always_return_bump_stamp = ( + # We use the membership event position for any non-join + room_membership_for_user_at_to_token.membership != Membership.JOIN + # We didn't fetch any timeline events but we should still check for + # a bump_stamp that might be somewhere + or limited is None + # There might be a bump event somewhere before the timeline events + # that we fetched, that we didn't previously send down + or limited is True + # Always give the client some frame of reference if this is the + # first time they are seeing the room down the connection + or initial + ) + + # If we're joined to the room, we need to find the last bump event before the + # `to_token` + if room_membership_for_user_at_to_token.membership == Membership.JOIN: + # Try and get a bump stamp + new_bump_stamp = await self._get_bump_stamp( + room_id, + to_token, + timeline_events, + check_outside_timeline=always_return_bump_stamp, + ) + if new_bump_stamp is not None: + bump_stamp = new_bump_stamp + + 
if bump_stamp is None and always_return_bump_stamp: + # By default, just choose the membership event position for any non-join membership + bump_stamp = room_membership_for_user_at_to_token.event_pos.stream + + if bump_stamp is not None and bump_stamp < 0: + # We never want to send down negative stream orderings, as you can't + # sensibly compare positive and negative stream orderings (they have + # different meanings). + # + # A negative bump stamp here can only happen if the stream ordering + # of the membership event is negative (and there are no further bump + # stamps), which can happen if the server leaves and deletes a room, + # and then rejoins it. + # + # To deal with this, we just set the bump stamp to zero, which will + # shove this room to the bottom of the list. This is OK as the + # moment a new message happens in the room it will get put into a + # sensible order again. + bump_stamp = 0 + + room_sync_required_state_map_to_persist: Mapping[str, AbstractSet[str]] = ( + expanded_required_state_map + ) + if changed_required_state_map: + room_sync_required_state_map_to_persist = changed_required_state_map + + # Record the `room_sync_config` if we're `ignore_timeline_bound` (which means + # that the `timeline_limit` has increased) + unstable_expanded_timeline = False + if ignore_timeline_bound: + # FIXME: We signal the fact that we're sending down more events to + # the client by setting `unstable_expanded_timeline` to true (see + # "XXX: Odd behavior" above). + unstable_expanded_timeline = True + + new_connection_state.room_configs[room_id] = RoomSyncConfig( + timeline_limit=room_sync_config.timeline_limit, + required_state_map=room_sync_required_state_map_to_persist, + ) + elif prev_room_sync_config is not None: + # If the result is `limited` then we need to record that the + # `timeline_limit` has been reduced, as when/if the client later requests + # more timeline then we have more data to send. 
+ # + # Otherwise (when not `limited`) we don't need to record that the + # `timeline_limit` has been reduced, as the *effective* `timeline_limit` + # (i.e. the amount of timeline we have previously sent to the client) is at + # least the previous `timeline_limit`. + # + # This is to handle the case where the `timeline_limit` e.g. goes from 10 to + # 5 to 10 again (without any timeline gaps), where there's no point sending + # down the initial historical chunk events when the `timeline_limit` is + # increased as the client already has the 10 previous events. However, if + # client has a gap in the timeline (i.e. `limited` is True), then we *do* + # need to record the reduced timeline. + # + # TODO: Handle timeline gaps (`get_timeline_gaps()`) - This is separate from + # the gaps we might see on the client because a response was `limited` we're + # talking about above. + if ( + limited + and prev_room_sync_config.timeline_limit + > room_sync_config.timeline_limit + ): + new_connection_state.room_configs[room_id] = RoomSyncConfig( + timeline_limit=room_sync_config.timeline_limit, + required_state_map=room_sync_required_state_map_to_persist, + ) + + elif changed_required_state_map is not None: + new_connection_state.room_configs[room_id] = RoomSyncConfig( + timeline_limit=room_sync_config.timeline_limit, + required_state_map=room_sync_required_state_map_to_persist, + ) + + else: + new_connection_state.room_configs[room_id] = RoomSyncConfig( + timeline_limit=room_sync_config.timeline_limit, + required_state_map=room_sync_required_state_map_to_persist, + ) + + set_tag(SynapseTags.RESULT_PREFIX + "initial", initial) + + return SlidingSyncResult.RoomResult( + name=room_name, + avatar=room_avatar, + heroes=heroes, + is_dm=is_dm, + initial=initial, + required_state=list(required_room_state.values()), + timeline_events=timeline_events, + bundled_aggregations=bundled_aggregations, + stripped_state=stripped_state, + prev_batch=prev_batch_token, + limited=limited, + 
@trace
async def _get_bump_stamp(
    self,
    room_id: str,
    to_token: StreamToken,
    timeline: List[EventBase],
    check_outside_timeline: bool,
) -> Optional[int]:
    """Work out the bump stamp for a room, if a usable bump event exists.

    Candidates are tried in order: the timeline events being returned, then
    the room's latest bump stamp from the sliding sync tables, then the last
    bump event before `to_token`. Only strictly positive stream orderings
    are ever returned.

    Args:
        room_id: The room to find a bump stamp for.
        to_token: The upper bound of token to return.
        timeline: The list of events we have fetched.
        check_outside_timeline: Whether to keep looking at events before
            the timeline when no timeline event provides a bump stamp.

    Returns:
        The stream ordering of the chosen bump event, or `None` when no
        usable bump stamp was found.
    """

    # Newest events first: the first bump-type event with a positive stream
    # ordering wins.
    for event in reversed(timeline):
        if event.type not in SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES:
            continue

        stream_ordering = event.internal_metadata.stream_ordering

        # All persisted events have a stream ordering
        assert stream_ordering is not None

        # A bump event that was backfilled (e.g. we have just joined a
        # remote room) has a negative stream ordering. Those can't sensibly
        # be compared with positive ones, so skip it and keep looking.
        if stream_ordering > 0:
            return stream_ordering

    # If we are not a limited sync, then we know the bump stamp can't have
    # changed, so there is nothing more to look up.
    if not check_outside_timeline:
        return None

    # The sliding sync tables give us the latest bump stamp for the room
    # cheaply.
    stored_bump_stamp = await self.store.get_latest_bump_stamp_for_room(room_id)

    token_floor = to_token.room_key.stream

    if stored_bump_stamp is None:
        # FIXME: The background job check can be removed once we bump
        # `SCHEMA_COMPAT_VERSION` and run the foreground update for
        # `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots`
        # (tracked by https://github.com/element-hq/synapse/issues/17623)
        #
        # If the tables can be relied upon, a missing bump stamp is
        # definitive (this can happen when we've just joined a remote room
        # and all the events are backfilled).
        if await self.store.have_finished_sliding_sync_background_jobs():
            return None
    elif stored_bump_stamp < token_floor:
        # `bump_stamp` is only a `stream_ordering` position, so the only
        # case where we can be sure it is before the `to_token` is when it
        # is below the minimum position from the token. A non-`None` value
        # is known to be up to date, so no background update check is
        # needed here.
        return stored_bump_stamp if stored_bump_stamp > 0 else None

    # Otherwise the stored value is unavailable/unreliable, or it is within
    # or after the `to_token`: find the last bump event before the
    # `to_token`.
    last_bump_event_result = (
        await self.store.get_last_event_pos_in_room_before_stream_ordering(
            room_id,
            to_token.room_key,
            event_types=SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES,
        )
    )
    if last_bump_event_result is None:
        return None

    _, bump_event_pos = last_bump_event_result

    # As above, a backfilled bump event (negative stream ordering) is not
    # usable.
    if bump_event_pos.stream > 0:
        return bump_event_pos.stream

    return None
+ if StateValues.WILDCARD in prev_wildcard: + return request_required_state_map, StateFilter.none() + + # If a event type wildcard has been added or removed we don't try and do + # anything fancy, and instead always update the effective room required + # state config to match the request. + if request_wildcard - prev_wildcard: + # Some keys were added, so we need to fetch everything + return request_required_state_map, StateFilter.all() + if prev_wildcard - request_wildcard: + # Keys were only removed, so we don't have to fetch everything. + return request_required_state_map, StateFilter.none() + + # Contains updates to the required state map compared with the previous room + # config. This has the same format as `RoomSyncConfig.required_state` + changes: Dict[str, AbstractSet[str]] = {} + + # The set of types/state keys that we need to fetch and return to the + # client. Passed to `StateFilter.from_types(...)` + added: List[Tuple[str, Optional[str]]] = [] + + # Convert the list of state deltas to map from type to state_keys that have + # changed. + changed_types_to_state_keys: Dict[str, Set[str]] = {} + for event_type, state_key in state_deltas: + changed_types_to_state_keys.setdefault(event_type, set()).add(state_key) + + # First we calculate what, if anything, has been *added*. + for event_type in ( + prev_required_state_map.keys() | request_required_state_map.keys() + ): + old_state_keys = prev_required_state_map.get(event_type, set()) + request_state_keys = request_required_state_map.get(event_type, set()) + changed_state_keys = changed_types_to_state_keys.get(event_type, set()) + + if old_state_keys == request_state_keys: + # No change to this type + continue + + if not request_state_keys - old_state_keys: + # Nothing *added*, so we skip. Removals happen below. + continue + + # We only remove state keys from the effective state if they've been + # removed from the request *and* the state has changed. 
This ensures + # that if a client removes and then re-adds a state key, we only send + # down the associated current state event if its changed (rather than + # sending down the same event twice). + invalidated_state_keys = ( + old_state_keys - request_state_keys + ) & changed_state_keys + + # Figure out which state keys we should remember sending down the connection + inheritable_previous_state_keys = ( + # Retain the previous state_keys that we've sent down before. + # Wildcard and lazy state keys are not sticky from previous requests. + (old_state_keys - {StateValues.WILDCARD, StateValues.LAZY}) + - invalidated_state_keys + ) + + # Always update changes to include the newly added keys (we've expanded the set + # of state keys), use the new requested set with whatever hasn't been + # invalidated from the previous set. + changes[event_type] = request_state_keys | inheritable_previous_state_keys + # Limit the number of state_keys we should remember sending down the connection + # for each (room_id, user_id). We don't want to store and pull out too much data + # in the database. This is a happy-medium between remembering nothing and + # everything. We can avoid sending redundant state down the connection most of + # the time given that most rooms don't have 100 members anyway and it takes a + # while to cycle through 100 members. + # + # Only remember up to (MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER) + if len(changes[event_type]) > MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER: + # Reset back to only the requested state keys + changes[event_type] = request_state_keys + + # Skip if there isn't any room to fill in the rest with previous state keys + if len(request_state_keys) < MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER: + # Fill the rest with previous state_keys. Ideally, we could sort + # these by recency but it's just a set so just pick an arbitrary + # subset (good enough). 
+ changes[event_type] = changes[event_type] | set( + itertools.islice( + inheritable_previous_state_keys, + # Just taking the difference isn't perfect as there could be + # overlap in the keys between the requested and previous but we + # will decide to just take the easy route for now and avoid + # additional set operations to figure it out. + MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER + - len(request_state_keys), + ) + ) + + if StateValues.WILDCARD in old_state_keys: + # We were previously fetching everything for this type, so we don't need to + # fetch anything new. + continue + + # Record the new state keys to fetch for this type. + if StateValues.WILDCARD in request_state_keys: + # If we have added a wildcard then we always just fetch everything. + added.append((event_type, None)) + else: + for state_key in request_state_keys - old_state_keys: + if state_key == StateValues.ME: + added.append((event_type, user_id)) + elif state_key == StateValues.LAZY: + # We handle lazy loading separately (outside this function), + # so don't need to explicitly add anything here. + # + # LAZY values should also be ignore for event types that are + # not membership. + pass + else: + added.append((event_type, state_key)) + + added_state_filter = StateFilter.from_types(added) + + # Figure out what changes we need to apply to the effective required state + # config. + for event_type, changed_state_keys in changed_types_to_state_keys.items(): + old_state_keys = prev_required_state_map.get(event_type, set()) + request_state_keys = request_required_state_map.get(event_type, set()) + + if old_state_keys == request_state_keys: + # No change. + continue + + # If we see the `user_id` as a state_key, also add "$ME" to the list of state + # that has changed to account for people requesting `required_state` with `$ME` + # or their user ID. 
+ if user_id in changed_state_keys: + changed_state_keys.add(StateValues.ME) + + # We only remove state keys from the effective state if they've been + # removed from the request *and* the state has changed. This ensures + # that if a client removes and then re-adds a state key, we only send + # down the associated current state event if its changed (rather than + # sending down the same event twice). + invalidated_state_keys = ( + old_state_keys - request_state_keys + ) & changed_state_keys + + # We've expanded the set of state keys, ... (already handled above) + if request_state_keys - old_state_keys: + continue + + old_state_key_wildcard = StateValues.WILDCARD in old_state_keys + request_state_key_wildcard = StateValues.WILDCARD in request_state_keys + + if old_state_key_wildcard != request_state_key_wildcard: + # If a state_key wildcard has been added or removed, we always update the + # effective room required state config to match the request. + changes[event_type] = request_state_keys + continue + + if event_type == EventTypes.Member: + old_state_key_lazy = StateValues.LAZY in old_state_keys + request_state_key_lazy = StateValues.LAZY in request_state_keys + + if old_state_key_lazy != request_state_key_lazy: + # If a "$LAZY" has been added or removed we always update the effective room + # required state config to match the request. + changes[event_type] = request_state_keys + continue + + # At this point there are no wildcards and no additions to the set of + # state keys requested, only deletions. + # + # We only remove state keys from the effective state if they've been + # removed from the request *and* the state has changed. This ensures + # that if a client removes and then re-adds a state key, we only send + # down the associated current state event if its changed (rather than + # sending down the same event twice). 
+ if invalidated_state_keys: + changes[event_type] = old_state_keys - invalidated_state_keys + + if changes: + # Update the required state config based on the changes. + new_required_state_map = dict(prev_required_state_map) + for event_type, state_keys in changes.items(): + if state_keys: + new_required_state_map[event_type] = state_keys + else: + # Remove entries with empty state keys. + new_required_state_map.pop(event_type, None) + + return new_required_state_map, added_state_filter + else: + return None, added_state_filter diff --git a/synapse/handlers/sliding_sync/extensions.py b/synapse/handlers/sliding_sync/extensions.py new file mode 100644
index 0000000000..077887ec32 --- /dev/null +++ b/synapse/handlers/sliding_sync/extensions.py
@@ -0,0 +1,879 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2023 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# + +import itertools +import logging +from typing import ( + TYPE_CHECKING, + AbstractSet, + ChainMap, + Dict, + Mapping, + MutableMapping, + Optional, + Sequence, + Set, + cast, +) + +from typing_extensions import assert_never + +from synapse.api.constants import AccountDataTypes, EduTypes +from synapse.handlers.receipts import ReceiptEventSource +from synapse.logging.opentracing import trace +from synapse.storage.databases.main.receipts import ReceiptInRoom +from synapse.types import ( + DeviceListUpdates, + JsonMapping, + MultiWriterStreamToken, + SlidingSyncStreamToken, + StrCollection, + StreamToken, +) +from synapse.types.handlers.sliding_sync import ( + HaveSentRoomFlag, + MutablePerConnectionState, + OperationType, + PerConnectionState, + SlidingSyncConfig, + SlidingSyncResult, +) +from synapse.util.async_helpers import ( + concurrently_execute, + gather_optional_coroutines, +) + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class SlidingSyncExtensionHandler: + """Handles the extensions to sliding sync.""" + + def __init__(self, hs: "HomeServer"): + self.store = hs.get_datastores().main + self.event_sources = hs.get_event_sources() + self.device_handler = hs.get_device_handler() + self.push_rules_handler = hs.get_push_rules_handler() + + @trace + async def get_extensions_response( + self, + sync_config: SlidingSyncConfig, + previous_connection_state: "PerConnectionState", + 
new_connection_state: "MutablePerConnectionState", + actual_lists: Mapping[str, SlidingSyncResult.SlidingWindowList], + actual_room_ids: Set[str], + actual_room_response_map: Mapping[str, SlidingSyncResult.RoomResult], + to_token: StreamToken, + from_token: Optional[SlidingSyncStreamToken], + ) -> SlidingSyncResult.Extensions: + """Handle extension requests. + + Args: + sync_config: Sync configuration + new_connection_state: Snapshot of the current per-connection state + new_per_connection_state: A mutable copy of the per-connection + state, used to record updates to the state during this request. + actual_lists: Sliding window API. A map of list key to list results in the + Sliding Sync response. + actual_room_ids: The actual room IDs in the the Sliding Sync response. + actual_room_response_map: A map of room ID to room results in the the + Sliding Sync response. + to_token: The point in the stream to sync up to. + from_token: The point in the stream to sync from. + """ + + if sync_config.extensions is None: + return SlidingSyncResult.Extensions() + + to_device_coro = None + if sync_config.extensions.to_device is not None: + to_device_coro = self.get_to_device_extension_response( + sync_config=sync_config, + to_device_request=sync_config.extensions.to_device, + to_token=to_token, + ) + + e2ee_coro = None + if sync_config.extensions.e2ee is not None: + e2ee_coro = self.get_e2ee_extension_response( + sync_config=sync_config, + e2ee_request=sync_config.extensions.e2ee, + to_token=to_token, + from_token=from_token, + ) + + account_data_coro = None + if sync_config.extensions.account_data is not None: + account_data_coro = self.get_account_data_extension_response( + sync_config=sync_config, + previous_connection_state=previous_connection_state, + new_connection_state=new_connection_state, + actual_lists=actual_lists, + actual_room_ids=actual_room_ids, + account_data_request=sync_config.extensions.account_data, + to_token=to_token, + from_token=from_token, + ) + + 
def find_relevant_room_ids_for_extension(
    self,
    requested_lists: Optional[StrCollection],
    requested_room_ids: Optional[StrCollection],
    actual_lists: Mapping[str, SlidingSyncResult.SlidingWindowList],
    actual_room_ids: AbstractSet[str],
) -> Set[str]:
    """
    Resolve the reserved `lists`/`rooms` keys of an extension request into
    the set of rooms the extension should cover. Extensions should only
    return results for rooms in the Sliding Sync response, so the requested
    rooms/lists are matched up with the actual lists/rooms.

    {"lists": []}                    // Do not process any lists.
    {"lists": ["rooms", "dms"]}      // Process only a subset of lists.
    {"lists": ["*"]} // Process all lists defined in the Sliding Window API. (This is the default.)

    {"rooms": []}                    // Do not process any specific rooms.
    {"rooms": ["!a:b", "!c:d"]}      // Process only a subset of room subscriptions.
    {"rooms": ["*"]} // Process all room subscriptions defined in the Room Subscription API. (This is the default.)

    Args:
        requested_lists: The `lists` from the extension request.
        requested_room_ids: The `rooms` from the extension request.
        actual_lists: The actual lists from the Sliding Sync response.
        actual_room_ids: The room IDs to match `requested_room_ids` against.

    Returns:
        The room IDs selected by `requested_room_ids` (restricted to
        `actual_room_ids`) plus the rooms of every selected list.
    """

    def _room_ids_of(window_list: SlidingSyncResult.SlidingWindowList):
        # We only expect a single SYNC operation for any list
        assert len(window_list.ops) == 1
        only_op = window_list.ops[0]
        assert only_op.op == OperationType.SYNC
        return only_op.room_ids

    # Only rooms that are both requested by the extension and present in the
    # response end up in here.
    matched_room_ids: Set[str] = set()

    # Rooms selected through the `rooms` key.
    if requested_room_ids is not None:
        if "*" in requested_room_ids:
            # A wildcard means we process all rooms from the room subscriptions
            matched_room_ids.update(actual_room_ids)
        else:
            matched_room_ids.update(
                room_id
                for room_id in requested_room_ids
                if room_id in actual_room_ids
            )

    # Rooms selected through the `lists` key.
    if requested_lists is not None:
        for list_key in requested_lists:
            if list_key == "*":
                # A wildcard means we process rooms from all lists; nothing
                # after it can add anything new.
                for window_list in actual_lists.values():
                    matched_room_ids.update(_room_ids_of(window_list))
                break

            # Lists that are not part of the response are ignored.
            named_list = actual_lists.get(list_key)
            if named_list is not None:
                matched_room_ids.update(_room_ids_of(named_list))

    return matched_room_ids
@trace
async def get_e2ee_extension_response(
    self,
    sync_config: SlidingSyncConfig,
    e2ee_request: SlidingSyncConfig.Extensions.E2eeExtension,
    to_token: StreamToken,
    from_token: Optional[SlidingSyncStreamToken],
) -> Optional[SlidingSyncResult.Extensions.E2eeExtension]:
    """Build the response for the E2EE device extension (MSC3884).

    Args:
        sync_config: Sync configuration
        e2ee_request: The e2ee extension from the request
        to_token: The point in the stream to sync up to. Not currently
            consulted by this method.
        from_token: The point in the stream to sync from. Device list
            updates are only looked up when this is provided.

    Returns:
        `None` if the extension is not enabled, otherwise the device list
        updates together with the one-time key counts and unused fallback
        key types for the requesting device.
    """
    requester_user_id = sync_config.user.to_string()
    requester_device_id = sync_config.requester.device_id

    # Skip if the extension is not enabled
    if not e2ee_request.enabled:
        return None

    # Without a `from_token` there is nothing to compute changes against, so
    # the device list updates stay `None`.
    changed_device_lists: Optional[DeviceListUpdates] = None
    if from_token is not None:
        # TODO: This should take into account the `from_token` and `to_token`
        changed_device_lists = await self.device_handler.get_user_ids_changed(
            user_id=requester_user_id,
            from_token=from_token.stream_token,
        )

    # Defaults used when the request is not tied to a device.
    otk_counts: Mapping[str, int] = {}
    unused_fallback_key_types: Sequence[str] = []
    if requester_device_id:
        # TODO: We should have a way to let clients differentiate between the states of:
        #   * no change in OTK count since the provided since token
        #   * the server has zero OTKs left for this device
        #  Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
        otk_counts = await self.store.count_e2e_one_time_keys(
            requester_user_id, requester_device_id
        )
        unused_fallback_key_types = (
            await self.store.get_e2e_unused_fallback_key_types(
                requester_user_id, requester_device_id
            )
        )

    return SlidingSyncResult.Extensions.E2eeExtension(
        device_list_updates=changed_device_lists,
        device_one_time_keys_count=otk_counts,
        device_unused_fallback_key_types=unused_fallback_key_types,
    )
A map of list key to list results in the + Sliding Sync response. + actual_room_ids: The actual room IDs in the the Sliding Sync response. + account_data_request: The account_data extension from the request + to_token: The point in the stream to sync up to. + from_token: The point in the stream to sync from. + """ + user_id = sync_config.user.to_string() + + # Skip if the extension is not enabled + if not account_data_request.enabled: + return None + + global_account_data_map: Mapping[str, JsonMapping] = {} + if from_token is not None: + # TODO: This should take into account the `from_token` and `to_token` + global_account_data_map = ( + await self.store.get_updated_global_account_data_for_user( + user_id, from_token.stream_token.account_data_key + ) + ) + + # TODO: This should take into account the `from_token` and `to_token` + have_push_rules_changed = await self.store.have_push_rules_changed_for_user( + user_id, from_token.stream_token.push_rules_key + ) + if have_push_rules_changed: + # TODO: This should take into account the `from_token` and `to_token` + global_account_data_map[ + AccountDataTypes.PUSH_RULES + ] = await self.push_rules_handler.push_rules_for_user(sync_config.user) + else: + # TODO: This should take into account the `to_token` + immutable_global_account_data_map = ( + await self.store.get_global_account_data_for_user(user_id) + ) + + # Use a `ChainMap` to avoid copying the immutable data from the cache + global_account_data_map = ChainMap( + { + # TODO: This should take into account the `to_token` + AccountDataTypes.PUSH_RULES: await self.push_rules_handler.push_rules_for_user( + sync_config.user + ) + }, + # Cast is safe because `ChainMap` only mutates the top-most map, + # see https://github.com/python/typeshed/issues/8430 + cast( + MutableMapping[str, JsonMapping], immutable_global_account_data_map + ), + ) + + # Fetch room account data + # + account_data_by_room_map: MutableMapping[str, Mapping[str, JsonMapping]] = {} + relevant_room_ids = 
self.find_relevant_room_ids_for_extension( + requested_lists=account_data_request.lists, + requested_room_ids=account_data_request.rooms, + actual_lists=actual_lists, + actual_room_ids=actual_room_ids, + ) + if len(relevant_room_ids) > 0: + # We need to handle the different cases depending on if we have sent + # down account data previously or not, so we split the relevant + # rooms up into different collections based on status. + live_rooms = set() + previously_rooms: Dict[str, int] = {} + initial_rooms = set() + + for room_id in relevant_room_ids: + if not from_token: + initial_rooms.add(room_id) + continue + + room_status = previous_connection_state.account_data.have_sent_room( + room_id + ) + if room_status.status == HaveSentRoomFlag.LIVE: + live_rooms.add(room_id) + elif room_status.status == HaveSentRoomFlag.PREVIOUSLY: + assert room_status.last_token is not None + previously_rooms[room_id] = room_status.last_token + elif room_status.status == HaveSentRoomFlag.NEVER: + initial_rooms.add(room_id) + else: + assert_never(room_status.status) + + # We fetch all room account data since the from_token. This is so + # that we can record which rooms have updates that haven't been sent + # down. + # + # Mapping from room_id to mapping of `type` to `content` of room account + # data events. 
+ all_updates_since_the_from_token: Mapping[ + str, Mapping[str, JsonMapping] + ] = {} + if from_token is not None: + # TODO: This should take into account the `from_token` and `to_token` + all_updates_since_the_from_token = ( + await self.store.get_updated_room_account_data_for_user( + user_id, from_token.stream_token.account_data_key + ) + ) + + # Add room tags + # + # TODO: This should take into account the `from_token` and `to_token` + tags_by_room = await self.store.get_updated_tags( + user_id, from_token.stream_token.account_data_key + ) + for room_id, tags in tags_by_room.items(): + all_updates_since_the_from_token.setdefault(room_id, {})[ + AccountDataTypes.TAG + ] = {"tags": tags} + + # For live rooms we just get the updates from `all_updates_since_the_from_token` + if live_rooms: + for room_id in all_updates_since_the_from_token.keys() & live_rooms: + account_data_by_room_map[room_id] = ( + all_updates_since_the_from_token[room_id] + ) + + # For previously and initial rooms we query each room individually. + if previously_rooms or initial_rooms: + + async def handle_previously(room_id: str) -> None: + # Either get updates or all account data in the room + # depending on if the room state is PREVIOUSLY or NEVER. + previous_token = previously_rooms.get(room_id) + if previous_token is not None: + room_account_data = await ( + self.store.get_updated_room_account_data_for_user_for_room( + user_id=user_id, + room_id=room_id, + from_stream_id=previous_token, + to_stream_id=to_token.account_data_key, + ) + ) + + # Add room tags + changed = await self.store.has_tags_changed_for_room( + user_id=user_id, + room_id=room_id, + from_stream_id=previous_token, + to_stream_id=to_token.account_data_key, + ) + if changed: + # XXX: Ideally, this should take into account the `to_token` + # and return the set of tags at that time but we don't track + # changes to tags so we just have to return all tags for the + # room. 
+ immutable_tag_map = await self.store.get_tags_for_room( + user_id, room_id + ) + room_account_data[AccountDataTypes.TAG] = { + "tags": immutable_tag_map + } + + # Only add an entry if there were any updates. + if room_account_data: + account_data_by_room_map[room_id] = room_account_data + else: + # TODO: This should take into account the `to_token` + immutable_room_account_data = ( + await self.store.get_account_data_for_room(user_id, room_id) + ) + + # Add room tags + # + # XXX: Ideally, this should take into account the `to_token` + # and return the set of tags at that time but we don't track + # changes to tags so we just have to return all tags for the + # room. + immutable_tag_map = await self.store.get_tags_for_room( + user_id, room_id + ) + + account_data_by_room_map[room_id] = ChainMap( + {AccountDataTypes.TAG: {"tags": immutable_tag_map}} + if immutable_tag_map + else {}, + # Cast is safe because `ChainMap` only mutates the top-most map, + # see https://github.com/python/typeshed/issues/8430 + cast( + MutableMapping[str, JsonMapping], + immutable_room_account_data, + ), + ) + + # We handle these rooms concurrently to speed it up. + await concurrently_execute( + handle_previously, + previously_rooms.keys() | initial_rooms, + limit=20, + ) + + # Now record which rooms are now up to date, and which rooms have + # pending updates to send. + new_connection_state.account_data.record_sent_rooms(previously_rooms.keys()) + new_connection_state.account_data.record_sent_rooms(initial_rooms) + missing_updates = ( + all_updates_since_the_from_token.keys() - relevant_room_ids + ) + if missing_updates: + # If we have missing updates then we must have had a from_token.
+ assert from_token is not None + + new_connection_state.account_data.record_unsent_rooms( + missing_updates, from_token.stream_token.account_data_key + ) + + return SlidingSyncResult.Extensions.AccountDataExtension( + global_account_data_map=global_account_data_map, + account_data_by_room_map=account_data_by_room_map, + ) + + @trace + async def get_receipts_extension_response( + self, + sync_config: SlidingSyncConfig, + previous_connection_state: "PerConnectionState", + new_connection_state: "MutablePerConnectionState", + actual_lists: Mapping[str, SlidingSyncResult.SlidingWindowList], + actual_room_ids: Set[str], + actual_room_response_map: Mapping[str, SlidingSyncResult.RoomResult], + receipts_request: SlidingSyncConfig.Extensions.ReceiptsExtension, + to_token: StreamToken, + from_token: Optional[SlidingSyncStreamToken], + ) -> Optional[SlidingSyncResult.Extensions.ReceiptsExtension]: + """Handle Receipts extension (MSC3960) + + Args: + sync_config: Sync configuration + previous_connection_state: The current per-connection state + new_connection_state: A mutable copy of the per-connection + state, used to record updates to the state. + actual_lists: Sliding window API. A map of list key to list results in the + Sliding Sync response. + actual_room_ids: The actual room IDs in the Sliding Sync response. + actual_room_response_map: A map of room ID to room results in the + Sliding Sync response. + receipts_request: The receipts extension from the request + to_token: The point in the stream to sync up to. + from_token: The point in the stream to sync from.
+ """ + # Skip if the extension is not enabled + if not receipts_request.enabled: + return None + + relevant_room_ids = self.find_relevant_room_ids_for_extension( + requested_lists=receipts_request.lists, + requested_room_ids=receipts_request.rooms, + actual_lists=actual_lists, + actual_room_ids=actual_room_ids, + ) + + room_id_to_receipt_map: Dict[str, JsonMapping] = {} + if len(relevant_room_ids) > 0: + # We need to handle the different cases depending on if we have sent + # down receipts previously or not, so we split the relevant rooms + # up into different collections based on status. + live_rooms = set() + previously_rooms: Dict[str, MultiWriterStreamToken] = {} + initial_rooms = set() + + for room_id in relevant_room_ids: + if not from_token: + initial_rooms.add(room_id) + continue + + # If we're sending down the room from scratch again for some + # reason, we should always resend the receipts as well + # (regardless of if we've sent them down before). This is to + # mimic the behaviour of what happens on initial sync, where you + # get a chunk of timeline with all of the corresponding receipts + # for the events in the timeline. + # + # We also resend down receipts when we "expand" the timeline, + # (see the "XXX: Odd behavior" in + # `synapse.handlers.sliding_sync`). + room_result = actual_room_response_map.get(room_id) + if room_result is not None: + if room_result.initial or room_result.unstable_expanded_timeline: + initial_rooms.add(room_id) + continue + + room_status = previous_connection_state.receipts.have_sent_room(room_id) + if room_status.status == HaveSentRoomFlag.LIVE: + live_rooms.add(room_id) + elif room_status.status == HaveSentRoomFlag.PREVIOUSLY: + assert room_status.last_token is not None + previously_rooms[room_id] = room_status.last_token + elif room_status.status == HaveSentRoomFlag.NEVER: + initial_rooms.add(room_id) + else: + assert_never(room_status.status) + + # The set of receipts that we fetched. 
Private receipts need to be + # filtered out before returning. + fetched_receipts = [] + + # For live rooms we just fetch all receipts in those rooms since the + # `since` token. + if live_rooms: + assert from_token is not None + receipts = await self.store.get_linearized_receipts_for_rooms( + room_ids=live_rooms, + from_key=from_token.stream_token.receipt_key, + to_key=to_token.receipt_key, + ) + fetched_receipts.extend(receipts) + + # For rooms we've previously sent down, but aren't up to date, we + # need to use the from token from the room status. + if previously_rooms: + # Fetch any missing rooms concurrently. + + async def handle_previously_room(room_id: str) -> None: + receipt_token = previously_rooms[room_id] + # TODO: Limit the number of receipts we're about to send down + # for the room, if its too many we should TODO + previously_receipts = ( + await self.store.get_linearized_receipts_for_room( + room_id=room_id, + from_key=receipt_token, + to_key=to_token.receipt_key, + ) + ) + fetched_receipts.extend(previously_receipts) + + await concurrently_execute( + handle_previously_room, previously_rooms.keys(), 20 + ) + + if initial_rooms: + # We also always send down receipts for the current user. + user_receipts = ( + await self.store.get_linearized_receipts_for_user_in_rooms( + user_id=sync_config.user.to_string(), + room_ids=initial_rooms, + to_key=to_token.receipt_key, + ) + ) + + # For rooms we haven't previously sent down, we could send all receipts + # from that room but we only want to include receipts for events + # in the timeline to avoid bloating and blowing up the sync response + # as the number of users in the room increases. 
(this behavior is part of the spec) + initial_rooms_and_event_ids = [ + (room_id, event.event_id) + for room_id in initial_rooms + if room_id in actual_room_response_map + for event in actual_room_response_map[room_id].timeline_events + ] + initial_receipts = await self.store.get_linearized_receipts_for_events( + room_and_event_ids=initial_rooms_and_event_ids, + ) + + # Combine the receipts for a room and add them to + # `fetched_receipts` + for room_id in initial_receipts.keys() | user_receipts.keys(): + receipt_content = ReceiptInRoom.merge_to_content( + list( + itertools.chain( + initial_receipts.get(room_id, []), + user_receipts.get(room_id, []), + ) + ) + ) + + fetched_receipts.append( + { + "room_id": room_id, + "type": EduTypes.RECEIPT, + "content": receipt_content, + } + ) + + fetched_receipts = ReceiptEventSource.filter_out_private_receipts( + fetched_receipts, sync_config.user.to_string() + ) + + for receipt in fetched_receipts: + # These fields should exist for every receipt + room_id = receipt["room_id"] + type = receipt["type"] + content = receipt["content"] + + room_id_to_receipt_map[room_id] = {"type": type, "content": content} + + # Update the per-connection state to track which rooms we have sent + # all the receipts for. + new_connection_state.receipts.record_sent_rooms(previously_rooms.keys()) + new_connection_state.receipts.record_sent_rooms(initial_rooms) + + if from_token: + # Now find the set of rooms that may have receipts that we're not sending + # down. We only need to check rooms that we have previously returned + # receipts for (in `previous_connection_state`) because we only care about + # updating `LIVE` rooms to `PREVIOUSLY`. The `PREVIOUSLY` rooms will just + # stay pointing at their previous position so we don't need to waste time + # checking those and since we default to `NEVER`, rooms that were `NEVER` + # sent before don't need to be recorded as we'll handle them correctly when + # they come into range for the first time. 
+ rooms_no_receipts = [ + room_id + for room_id, room_status in previous_connection_state.receipts._statuses.items() + if room_status.status == HaveSentRoomFlag.LIVE + and room_id not in relevant_room_ids + ] + changed_rooms = await self.store.get_rooms_with_receipts_between( + rooms_no_receipts, + from_key=from_token.stream_token.receipt_key, + to_key=to_token.receipt_key, + ) + new_connection_state.receipts.record_unsent_rooms( + changed_rooms, from_token.stream_token.receipt_key + ) + + return SlidingSyncResult.Extensions.ReceiptsExtension( + room_id_to_receipt_map=room_id_to_receipt_map, + ) + + async def get_typing_extension_response( + self, + sync_config: SlidingSyncConfig, + actual_lists: Mapping[str, SlidingSyncResult.SlidingWindowList], + actual_room_ids: Set[str], + actual_room_response_map: Mapping[str, SlidingSyncResult.RoomResult], + typing_request: SlidingSyncConfig.Extensions.TypingExtension, + to_token: StreamToken, + from_token: Optional[SlidingSyncStreamToken], + ) -> Optional[SlidingSyncResult.Extensions.TypingExtension]: + """Handle Typing Notification extension (MSC3961) + + Args: + sync_config: Sync configuration + actual_lists: Sliding window API. A map of list key to list results in the + Sliding Sync response. + actual_room_ids: The actual room IDs in the Sliding Sync response. + actual_room_response_map: A map of room ID to room results in the + Sliding Sync response. + typing_request: The typing extension from the request + to_token: The point in the stream to sync up to. + from_token: The point in the stream to sync from.
+ """ + # Skip if the extension is not enabled + if not typing_request.enabled: + return None + + relevant_room_ids = self.find_relevant_room_ids_for_extension( + requested_lists=typing_request.lists, + requested_room_ids=typing_request.rooms, + actual_lists=actual_lists, + actual_room_ids=actual_room_ids, + ) + + room_id_to_typing_map: Dict[str, JsonMapping] = {} + if len(relevant_room_ids) > 0: + # Note: We don't need to take connection tracking into account for typing + # notifications because they'll get anything still relevant and hasn't timed + # out when the room comes into range. We consider the gap where the room + # fell out of range, as long enough for any typing notifications to have + # timed out (it's not worth the 30 seconds of data we may have missed). + typing_source = self.event_sources.sources.typing + typing_notifications, _ = await typing_source.get_new_events( + user=sync_config.user, + from_key=(from_token.stream_token.typing_key if from_token else 0), + to_key=to_token.typing_key, + # This is a dummy value and isn't used in the function + limit=0, + room_ids=relevant_room_ids, + is_guest=False, + ) + + for typing_notification in typing_notifications: + # These fields should exist for every typing notification + room_id = typing_notification["room_id"] + type = typing_notification["type"] + content = typing_notification["content"] + + room_id_to_typing_map[room_id] = {"type": type, "content": content} + + return SlidingSyncResult.Extensions.TypingExtension( + room_id_to_typing_map=room_id_to_typing_map, + ) diff --git a/synapse/handlers/sliding_sync/room_lists.py b/synapse/handlers/sliding_sync/room_lists.py new file mode 100644
index 0000000000..13e69f18a0 --- /dev/null +++ b/synapse/handlers/sliding_sync/room_lists.py
@@ -0,0 +1,2304 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2023 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# + + +import enum +import logging +from itertools import chain +from typing import ( + TYPE_CHECKING, + AbstractSet, + Dict, + List, + Literal, + Mapping, + Optional, + Set, + Tuple, + Union, + cast, +) + +import attr +from immutabledict import immutabledict +from typing_extensions import assert_never + +from synapse.api.constants import ( + AccountDataTypes, + EventContentFields, + EventTypes, + Membership, +) +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.events import StrippedStateEvent +from synapse.events.utils import parse_stripped_state_event +from synapse.logging.opentracing import start_active_span, trace +from synapse.storage.databases.main.state import ( + ROOM_UNKNOWN_SENTINEL, + Sentinel as StateSentinel, +) +from synapse.storage.databases.main.stream import CurrentStateDeltaMembership +from synapse.storage.invite_rule import InviteRule +from synapse.storage.roommember import ( + RoomsForUser, + RoomsForUserSlidingSync, + RoomsForUserStateReset, +) +from synapse.types import ( + MutableStateMap, + RoomStreamToken, + StateMap, + StrCollection, + StreamKeyType, + StreamToken, + UserID, +) +from synapse.types.handlers.sliding_sync import ( + HaveSentRoomFlag, + OperationType, + PerConnectionState, + RoomSyncConfig, + SlidingSyncConfig, + SlidingSyncResult, +) +from synapse.types.state import StateFilter + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +logger = logging.getLogger(__name__) + + +class 
Sentinel(enum.Enum): + # defining a sentinel in this way allows mypy to correctly handle the + # type of a dictionary lookup and subsequent type narrowing. + UNSET_SENTINEL = object() + + +# Helper definition for the types that we might return. We do this to avoid +# copying data between types (which can be expensive for many rooms). +RoomsForUserType = Union[RoomsForUserStateReset, RoomsForUser, RoomsForUserSlidingSync] + + +@attr.s(auto_attribs=True, slots=True, frozen=True) +class SlidingSyncInterestedRooms: + """The set of rooms and metadata a client is interested in based on their + sliding sync request. + + Returned by `compute_interested_rooms`. + + Attributes: + lists: A mapping from list name to the list result for the response + relevant_room_map: A map from rooms that match the sync request to + their room sync config. + relevant_rooms_to_send_map: Subset of `relevant_room_map` that + includes the rooms that *may* have relevant updates. Rooms not + in this map will definitely not have room updates (though + extensions may have updates in these rooms). + newly_joined_rooms: The set of rooms that were joined in the token range + and the user is still joined to at the end of this range. + newly_left_rooms: The set of rooms that we left in the token range + and are still "leave" at the end of this range. 
+ dm_room_ids: The set of rooms the user considers as direct-message (DM) rooms + """ + + lists: Mapping[str, SlidingSyncResult.SlidingWindowList] + relevant_room_map: Mapping[str, RoomSyncConfig] + relevant_rooms_to_send_map: Mapping[str, RoomSyncConfig] + all_rooms: Set[str] + room_membership_for_user_map: Mapping[str, RoomsForUserType] + + newly_joined_rooms: AbstractSet[str] + newly_left_rooms: AbstractSet[str] + dm_room_ids: AbstractSet[str] + + @staticmethod + def empty() -> "SlidingSyncInterestedRooms": + return SlidingSyncInterestedRooms( + lists={}, + relevant_room_map={}, + relevant_rooms_to_send_map={}, + all_rooms=set(), + room_membership_for_user_map={}, + newly_joined_rooms=set(), + newly_left_rooms=set(), + dm_room_ids=set(), + ) + + +def filter_membership_for_sync( + *, + user_id: str, + room_membership_for_user: RoomsForUserType, + newly_left: bool, +) -> bool: + """ + Returns True if the membership event should be included in the sync response, + otherwise False. + + Args: + user_id: The user ID that the membership applies to + room_membership_for_user: Membership information for the user in the room + newly_left: Whether the user newly left the room; newly left rooms are + always included + """ + + membership = room_membership_for_user.membership + sender = room_membership_for_user.sender + + # We want to allow everything except rooms the user has left unless `newly_left` + # because we want everything that's *still* relevant to the user. We include + # `newly_left` rooms because the last event that the user should see is their own + # leave event. + # + # A leave != kick. This logic includes kicks (leave events where the sender is not + # the same user). + # + # When `sender=None`, it means that a state reset happened that removed the user + # from the room without a corresponding leave event. We can just remove the rooms + # since they are no longer relevant to the user but will still appear if they are + # `newly_left`. + return ( + # Anything except leave events + membership != Membership.LEAVE + # Unless...
+ or newly_left + # Allow kicks + or (membership == Membership.LEAVE and sender not in (user_id, None)) + ) + + +class SlidingSyncRoomLists: + """Handles calculating the room lists from sliding sync requests""" + + def __init__(self, hs: "HomeServer"): + self.store = hs.get_datastores().main + self.storage_controllers = hs.get_storage_controllers() + self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync + self.is_mine_id = hs.is_mine_id + + async def compute_interested_rooms( + self, + sync_config: SlidingSyncConfig, + previous_connection_state: "PerConnectionState", + to_token: StreamToken, + from_token: Optional[StreamToken], + ) -> SlidingSyncInterestedRooms: + """Fetch the set of rooms that match the request""" + has_lists = sync_config.lists is not None and len(sync_config.lists) > 0 + has_room_subscriptions = ( + sync_config.room_subscriptions is not None + and len(sync_config.room_subscriptions) > 0 + ) + + if not has_lists and not has_room_subscriptions: + return SlidingSyncInterestedRooms.empty() + + if await self.store.have_finished_sliding_sync_background_jobs(): + return await self._compute_interested_rooms_new_tables( + sync_config=sync_config, + previous_connection_state=previous_connection_state, + to_token=to_token, + from_token=from_token, + ) + else: + # FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the + # foreground update for + # `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by + # https://github.com/element-hq/synapse/issues/17623) + return await self._compute_interested_rooms_fallback( + sync_config=sync_config, + previous_connection_state=previous_connection_state, + to_token=to_token, + from_token=from_token, + ) + + @trace + async def _compute_interested_rooms_new_tables( + self, + sync_config: SlidingSyncConfig, + previous_connection_state: "PerConnectionState", + to_token: StreamToken, + from_token: Optional[StreamToken], + ) -> SlidingSyncInterestedRooms: + 
"""Implementation of `compute_interested_rooms` using new sliding sync db tables.""" + user_id = sync_config.user.to_string() + + # Assemble sliding window lists + lists: Dict[str, SlidingSyncResult.SlidingWindowList] = {} + # Keep track of the rooms that we can display and need to fetch more info about + relevant_room_map: Dict[str, RoomSyncConfig] = {} + # The set of room IDs of all rooms that could appear in any list. These + # include rooms that are outside the list ranges. + all_rooms: Set[str] = set() + + # Note: this won't include rooms the user has left themselves. We add back + # `newly_left` rooms below. This is more efficient than fetching all rooms and + # then filtering out the old left rooms. + room_membership_for_user_map = ( + await self.store.get_sliding_sync_rooms_for_user_from_membership_snapshots( + user_id + ) + ) + # To play nice with the rewind logic below, we need to go fetch the rooms the + # user has left themselves but only if it changed after the `to_token`. + # + # If a leave happens *after* the token range, we may have still been joined (or + # any non-self-leave which is relevant to sync) to the room before so we need to + # include it in the list of potentially relevant rooms and apply our rewind + # logic (outside of this function) to see if it's actually relevant. + # + # We do this separately from + # `get_sliding_sync_rooms_for_user_from_membership_snapshots` as those results + # are cached and the `to_token` isn't very cache friendly (people are constantly + # requesting with new tokens) so we separate it out here. + self_leave_room_membership_for_user_map = ( + await self.store.get_sliding_sync_self_leave_rooms_after_to_token( + user_id, to_token + ) + ) + if self_leave_room_membership_for_user_map: + # FIXME: It would be nice to avoid this copy but since + # `get_sliding_sync_rooms_for_user_from_membership_snapshots` is cached, it + # can't return a mutable value like a `dict`. 
We make the copy to get a + # mutable dict that we can change. We try to only make a copy when necessary + # (if we actually need to change something) as in most cases, the logic + # doesn't need to run. + room_membership_for_user_map = dict(room_membership_for_user_map) + room_membership_for_user_map.update(self_leave_room_membership_for_user_map) + + # Remove invites from ignored users + ignored_users = await self.store.ignored_users(user_id) + invite_config = await self.store.get_invite_config_for_user(user_id) + if ignored_users: + # FIXME: It would be nice to avoid this copy but since + # `get_sliding_sync_rooms_for_user_from_membership_snapshots` is cached, it + # can't return a mutable value like a `dict`. We make the copy to get a + # mutable dict that we can change. We try to only make a copy when necessary + # (if we actually need to change something) as in most cases, the logic + # doesn't need to run. + room_membership_for_user_map = dict(room_membership_for_user_map) + # Make a copy so we don't run into an error: `dictionary changed size during + # iteration`, when we remove items + for room_id in list(room_membership_for_user_map.keys()): + room_for_user_sliding_sync = room_membership_for_user_map[room_id] + if ( + room_for_user_sliding_sync.membership == Membership.INVITE + and room_for_user_sliding_sync.sender + and ( + room_for_user_sliding_sync.sender in ignored_users + or invite_config.get_invite_rule( + room_for_user_sliding_sync.sender + ) + == InviteRule.IGNORE + ) + ): + room_membership_for_user_map.pop(room_id, None) + + ( + newly_joined_room_ids, + newly_left_room_map, + ) = await self._get_newly_joined_and_left_rooms( + user_id, from_token=from_token, to_token=to_token + ) + + changes = await self._get_rewind_changes_to_current_membership_to_token( + sync_config.user, room_membership_for_user_map, to_token=to_token + ) + if changes: + # FIXME: It would be nice to avoid this copy but since + # 
`get_sliding_sync_rooms_for_user_from_membership_snapshots` is cached, it + # can't return a mutable value like a `dict`. We make the copy to get a + # mutable dict that we can change. We try to only make a copy when necessary + # (if we actually need to change something) as in most cases, the logic + # doesn't need to run. + room_membership_for_user_map = dict(room_membership_for_user_map) + for room_id, change in changes.items(): + if change is None: + # Remove rooms that the user joined after the `to_token` + room_membership_for_user_map.pop(room_id, None) + continue + + existing_room = room_membership_for_user_map.get(room_id) + if existing_room is not None: + # Update room membership events to the point in time of the `to_token` + room_for_user = RoomsForUserSlidingSync( + room_id=room_id, + sender=change.sender, + membership=change.membership, + event_id=change.event_id, + event_pos=change.event_pos, + room_version_id=change.room_version_id, + # We keep the state of the room though + has_known_state=existing_room.has_known_state, + room_type=existing_room.room_type, + is_encrypted=existing_room.is_encrypted, + ) + if filter_membership_for_sync( + user_id=user_id, + room_membership_for_user=room_for_user, + newly_left=room_id in newly_left_room_map, + ): + room_membership_for_user_map[room_id] = room_for_user + else: + room_membership_for_user_map.pop(room_id, None) + + # Add back `newly_left` rooms (rooms left in the from -> to token range). + # + # We do this because `get_sliding_sync_rooms_for_user_from_membership_snapshots(...)` doesn't include + # rooms that the user left themselves as it's more efficient to add them back + # here than to fetch all rooms and then filter out the old left rooms. The user + # only leaves a room once in a blue moon so this barely needs to run. 
+ # + missing_newly_left_rooms = ( + newly_left_room_map.keys() - room_membership_for_user_map.keys() + ) + if missing_newly_left_rooms: + # FIXME: It would be nice to avoid this copy but since + # `get_sliding_sync_rooms_for_user_from_membership_snapshots` is cached, it + # can't return a mutable value like a `dict`. We make the copy to get a + # mutable dict that we can change. We try to only make a copy when necessary + # (if we actually need to change something) as in most cases, the logic + # doesn't need to run. + room_membership_for_user_map = dict(room_membership_for_user_map) + for room_id in missing_newly_left_rooms: + newly_left_room_for_user = newly_left_room_map[room_id] + # This should be a given + assert newly_left_room_for_user.membership == Membership.LEAVE + + # Add back `newly_left` rooms + # + # Check for membership and state in the Sliding Sync tables as it's just + # another membership + newly_left_room_for_user_sliding_sync = ( + await self.store.get_sliding_sync_room_for_user(user_id, room_id) + ) + # If the membership exists, it's just a normal user left the room on + # their own + if newly_left_room_for_user_sliding_sync is not None: + if filter_membership_for_sync( + user_id=user_id, + room_membership_for_user=newly_left_room_for_user_sliding_sync, + newly_left=room_id in newly_left_room_map, + ): + room_membership_for_user_map[room_id] = ( + newly_left_room_for_user_sliding_sync + ) + else: + room_membership_for_user_map.pop(room_id, None) + + change = changes.get(room_id) + if change is not None: + # Update room membership events to the point in time of the `to_token` + room_for_user = RoomsForUserSlidingSync( + room_id=room_id, + sender=change.sender, + membership=change.membership, + event_id=change.event_id, + event_pos=change.event_pos, + room_version_id=change.room_version_id, + # We keep the state of the room though + has_known_state=newly_left_room_for_user_sliding_sync.has_known_state, + 
room_type=newly_left_room_for_user_sliding_sync.room_type, + is_encrypted=newly_left_room_for_user_sliding_sync.is_encrypted, + ) + if filter_membership_for_sync( + user_id=user_id, + room_membership_for_user=room_for_user, + newly_left=room_id in newly_left_room_map, + ): + room_membership_for_user_map[room_id] = room_for_user + else: + room_membership_for_user_map.pop(room_id, None) + + # If we are `newly_left` from the room but can't find any membership, + # then we have been "state reset" out of the room + else: + # Get the state at the time. We can't read from the Sliding Sync + # tables because the user has no membership in the room according to + # the state (thanks to the state reset). + # + # Note: `room_type` never changes, so we can just get current room + # type + room_type = await self.store.get_room_type(room_id) + has_known_state = room_type is not ROOM_UNKNOWN_SENTINEL + if isinstance(room_type, StateSentinel): + room_type = None + + # Get the encryption status at the time of the token + is_encrypted = await self.get_is_encrypted_for_room_at_token( + room_id, + newly_left_room_for_user.event_pos.to_room_stream_token(), + ) + + room_for_user = RoomsForUserSlidingSync( + room_id=room_id, + sender=newly_left_room_for_user.sender, + membership=newly_left_room_for_user.membership, + event_id=newly_left_room_for_user.event_id, + event_pos=newly_left_room_for_user.event_pos, + room_version_id=newly_left_room_for_user.room_version_id, + has_known_state=has_known_state, + room_type=room_type, + is_encrypted=is_encrypted, + ) + if filter_membership_for_sync( + user_id=user_id, + room_membership_for_user=room_for_user, + newly_left=room_id in newly_left_room_map, + ): + room_membership_for_user_map[room_id] = room_for_user + else: + room_membership_for_user_map.pop(room_id, None) + + dm_room_ids = await self._get_dm_rooms_for_user(user_id) + + if sync_config.lists: + sync_room_map = room_membership_for_user_map + with 
start_active_span("assemble_sliding_window_lists"): + for list_key, list_config in sync_config.lists.items(): + # Apply filters + filtered_sync_room_map = sync_room_map + if list_config.filters is not None: + filtered_sync_room_map = await self.filter_rooms_using_tables( + user_id, + sync_room_map, + previous_connection_state, + list_config.filters, + to_token, + dm_room_ids, + ) + + # Find which rooms are partially stated and may need to be filtered out + # depending on the `required_state` requested (see below). + partial_state_rooms = await self.store.get_partial_rooms() + + # Since creating the `RoomSyncConfig` takes some work, let's just do it + # once. + room_sync_config = RoomSyncConfig.from_room_config(list_config) + + # Exclude partially-stated rooms if we must wait for the room to be + # fully-stated + if room_sync_config.must_await_full_state(self.is_mine_id): + filtered_sync_room_map = { + room_id: room + for room_id, room in filtered_sync_room_map.items() + if room_id not in partial_state_rooms + } + + all_rooms.update(filtered_sync_room_map) + + ops: List[SlidingSyncResult.SlidingWindowList.Operation] = [] + + if list_config.ranges: + # Optimization: If we are asking for the full range, we don't + # need to sort the list. + if ( + # We're looking for a single range that covers the entire list + len(list_config.ranges) == 1 + # Range starts at 0 + and list_config.ranges[0][0] == 0 + # And the range extends to the end of the list or more. Each + # side is inclusive. + and list_config.ranges[0][1] + >= len(filtered_sync_room_map) - 1 + ): + sorted_room_info: List[RoomsForUserType] = list( + filtered_sync_room_map.values() + ) + else: + # Sort the list + sorted_room_info = await self.sort_rooms( + # Cast is safe because RoomsForUserSlidingSync is part + # of the `RoomsForUserType` union. Why can't it detect this? 
+ cast( + Dict[str, RoomsForUserType], filtered_sync_room_map + ), + to_token, + # We only need to sort the rooms up to the end + # of the largest range. Both sides of range are + # inclusive so we `+ 1`. + limit=max(range[1] + 1 for range in list_config.ranges), + ) + + for range in list_config.ranges: + room_ids_in_list: List[str] = [] + + # We're going to loop through the sorted list of rooms starting + # at the range start index and keep adding rooms until we fill + # up the range or run out of rooms. + # + # Both sides of range are inclusive so we `+ 1` + max_num_rooms = range[1] - range[0] + 1 + for room_membership in sorted_room_info[range[0] :]: + room_id = room_membership.room_id + + if len(room_ids_in_list) >= max_num_rooms: + break + + # Take the superset of the `RoomSyncConfig` for each room. + # + # Update our `relevant_room_map` with the room we're going + # to display and need to fetch more info about. + existing_room_sync_config = relevant_room_map.get( + room_id + ) + if existing_room_sync_config is not None: + room_sync_config = existing_room_sync_config.combine_room_sync_config( + room_sync_config + ) + + relevant_room_map[room_id] = room_sync_config + + room_ids_in_list.append(room_id) + + ops.append( + SlidingSyncResult.SlidingWindowList.Operation( + op=OperationType.SYNC, + range=range, + room_ids=room_ids_in_list, + ) + ) + + lists[list_key] = SlidingSyncResult.SlidingWindowList( + count=len(filtered_sync_room_map), + ops=ops, + ) + + if sync_config.room_subscriptions: + with start_active_span("assemble_room_subscriptions"): + # FIXME: It would be nice to avoid this copy but since + # `get_sliding_sync_rooms_for_user_from_membership_snapshots` is cached, it + # can't return a mutable value like a `dict`. We make the copy to get a + # mutable dict that we can change. We try to only make a copy when necessary + # (if we actually need to change something) as in most cases, the logic + # doesn't need to run. 
+ room_membership_for_user_map = dict(room_membership_for_user_map) + + # Find which rooms are partially stated and may need to be filtered out + # depending on the `required_state` requested (see below). + partial_state_rooms = await self.store.get_partial_rooms() + + # Fetch any rooms that we have not already fetched from the database. + subscription_sliding_sync_rooms = ( + await self.store.get_sliding_sync_room_for_user_batch( + user_id, + sync_config.room_subscriptions.keys() + - room_membership_for_user_map.keys(), + ) + ) + room_membership_for_user_map.update(subscription_sliding_sync_rooms) + + for ( + room_id, + room_subscription, + ) in sync_config.room_subscriptions.items(): + # Check if we have a membership for the room, but didn't pull it out + # above. This could be e.g. a leave that we don't pull out by + # default. + current_room_entry = room_membership_for_user_map.get(room_id) + if not current_room_entry: + # TODO: Handle rooms the user isn't in. + continue + + all_rooms.add(room_id) + + # Take the superset of the `RoomSyncConfig` for each room. + room_sync_config = RoomSyncConfig.from_room_config( + room_subscription + ) + + # Exclude partially-stated rooms if we must wait for the room to be + # fully-stated + if room_sync_config.must_await_full_state(self.is_mine_id): + if room_id in partial_state_rooms: + continue + + # Update our `relevant_room_map` with the room we're going to display + # and need to fetch more info about. 
async def _compute_interested_rooms_fallback(
    self,
    sync_config: SlidingSyncConfig,
    previous_connection_state: "PerConnectionState",
    to_token: StreamToken,
    from_token: Optional[StreamToken],
) -> SlidingSyncInterestedRooms:
    """Fallback code when the database background updates haven't completed yet.

    Works out which rooms the request is interested in from the user's
    membership snapshot (`get_room_membership_for_user_at_to_token`) rather
    than from the sliding sync tables.

    Args:
        sync_config: The sliding sync request; its `lists` and
            `room_subscriptions` are processed here.
        previous_connection_state: Per-connection state from the previous
            request, forwarded to `filter_rooms` and
            `_filter_relevant_rooms_to_send`.
        to_token: The token to compute memberships/room lists up to.
        from_token: The point in the stream to sync from (`None` on an
            initial sync).

    Returns:
        The assembled `SlidingSyncInterestedRooms` (lists, per-room configs,
        the subset of rooms that may have updates, and membership info).
    """

    (
        room_membership_for_user_map,
        newly_joined_room_ids,
        newly_left_room_ids,
    ) = await self.get_room_membership_for_user_at_to_token(
        sync_config.user, to_token, from_token
    )

    dm_room_ids = await self._get_dm_rooms_for_user(sync_config.user.to_string())

    # Assemble sliding window lists
    lists: Dict[str, SlidingSyncResult.SlidingWindowList] = {}
    # Keep track of the rooms that we can display and need to fetch more info about
    relevant_room_map: Dict[str, RoomSyncConfig] = {}
    # The set of room IDs of all rooms that could appear in any list. These
    # include rooms that are outside the list ranges.
    all_rooms: Set[str] = set()

    if sync_config.lists:
        with start_active_span("assemble_sliding_window_lists"):
            sync_room_map = await self.filter_rooms_relevant_for_sync(
                user=sync_config.user,
                room_membership_for_user_map=room_membership_for_user_map,
                newly_left_room_ids=newly_left_room_ids,
            )

            # Find which rooms are partially stated and may need to be filtered
            # out depending on the `required_state` requested (see below). This
            # does not depend on the list, so fetch it once for all lists.
            partial_state_rooms = await self.store.get_partial_rooms()

            for list_key, list_config in sync_config.lists.items():
                # Apply filters
                filtered_sync_room_map = sync_room_map
                if list_config.filters is not None:
                    filtered_sync_room_map = await self.filter_rooms(
                        sync_config.user,
                        sync_room_map,
                        previous_connection_state,
                        list_config.filters,
                        to_token,
                        dm_room_ids,
                    )

                # Since creating the `RoomSyncConfig` takes some work, let's just
                # do it once per list.
                list_room_sync_config = RoomSyncConfig.from_room_config(list_config)

                # Exclude partially-stated rooms if we must wait for the room to be
                # fully-stated
                if list_room_sync_config.must_await_full_state(self.is_mine_id):
                    filtered_sync_room_map = {
                        room_id: room
                        for room_id, room in filtered_sync_room_map.items()
                        if room_id not in partial_state_rooms
                    }

                all_rooms.update(filtered_sync_room_map)

                # Sort the list
                sorted_room_info = await self.sort_rooms(
                    filtered_sync_room_map, to_token
                )

                ops: List[SlidingSyncResult.SlidingWindowList.Operation] = []
                if list_config.ranges:
                    for list_range in list_config.ranges:
                        room_ids_in_list: List[str] = []

                        # We're going to loop through the sorted list of rooms
                        # starting at the range start index and keep adding rooms
                        # until we fill up the range or run out of rooms.
                        #
                        # Both sides of range are inclusive so we `+ 1`
                        max_num_rooms = list_range[1] - list_range[0] + 1
                        for room_membership in sorted_room_info[list_range[0] :]:
                            if len(room_ids_in_list) >= max_num_rooms:
                                break

                            room_id = room_membership.room_id

                            # Take the superset of the `RoomSyncConfig` for each
                            # room.
                            #
                            # The merged config is kept in a per-room variable:
                            # rebinding the list-level config here would leak one
                            # room's superset (from another list/subscription)
                            # into every later room of this list.
                            config_for_room = list_room_sync_config
                            existing_room_sync_config = relevant_room_map.get(
                                room_id
                            )
                            if existing_room_sync_config is not None:
                                config_for_room = existing_room_sync_config.combine_room_sync_config(
                                    list_room_sync_config
                                )

                            # Update our `relevant_room_map` with the room we're
                            # going to display and need to fetch more info about.
                            relevant_room_map[room_id] = config_for_room

                            room_ids_in_list.append(room_id)

                        ops.append(
                            SlidingSyncResult.SlidingWindowList.Operation(
                                op=OperationType.SYNC,
                                range=list_range,
                                room_ids=room_ids_in_list,
                            )
                        )

                lists[list_key] = SlidingSyncResult.SlidingWindowList(
                    count=len(sorted_room_info),
                    ops=ops,
                )

    if sync_config.room_subscriptions:
        with start_active_span("assemble_room_subscriptions"):
            # Find which rooms are partially stated and may need to be filtered out
            # depending on the `required_state` requested (see below).
            partial_state_rooms = await self.store.get_partial_rooms()

            for (
                room_id,
                room_subscription,
            ) in sync_config.room_subscriptions.items():
                room_membership_for_user_at_to_token = (
                    await self.check_room_subscription_allowed_for_user(
                        room_id=room_id,
                        room_membership_for_user_map=room_membership_for_user_map,
                        to_token=to_token,
                    )
                )

                # Skip this room if the user isn't allowed to see it
                if not room_membership_for_user_at_to_token:
                    continue

                # Recorded before the partial-state check below, so a subscribed
                # room counts towards `all_rooms` even when it is skipped for
                # being partially stated.
                all_rooms.add(room_id)

                room_membership_for_user_map[room_id] = (
                    room_membership_for_user_at_to_token
                )

                # Take the superset of the `RoomSyncConfig` for each room.
                room_sync_config = RoomSyncConfig.from_room_config(room_subscription)

                # Exclude partially-stated rooms if we must wait for the room to be
                # fully-stated
                if (
                    room_sync_config.must_await_full_state(self.is_mine_id)
                    and room_id in partial_state_rooms
                ):
                    continue

                # Update our `relevant_room_map` with the room we're going to
                # display and need to fetch more info about.
                existing_room_sync_config = relevant_room_map.get(room_id)
                if existing_room_sync_config is not None:
                    room_sync_config = (
                        existing_room_sync_config.combine_room_sync_config(
                            room_sync_config
                        )
                    )

                relevant_room_map[room_id] = room_sync_config

    # Filtered subset of `relevant_room_map` for rooms that may have updates
    # (in the event stream)
    relevant_rooms_to_send_map = await self._filter_relevant_rooms_to_send(
        previous_connection_state, from_token, relevant_room_map
    )

    return SlidingSyncInterestedRooms(
        lists=lists,
        relevant_room_map=relevant_room_map,
        relevant_rooms_to_send_map=relevant_rooms_to_send_map,
        all_rooms=all_rooms,
        room_membership_for_user_map=room_membership_for_user_map,
        newly_joined_rooms=newly_joined_room_ids,
        newly_left_rooms=newly_left_room_ids,
        dm_room_ids=dm_room_ids,
    )
updates we need to send (i.e. either because + # we haven't sent the room down, or we have but there are missing + # updates). + for room_id, room_config in relevant_room_map.items(): + prev_room_sync_config = ( + previous_connection_state.room_configs.get(room_id) + ) + if prev_room_sync_config is not None: + # Always include rooms whose timeline limit has increased. + # (see the "XXX: Odd behavior" described below) + if ( + prev_room_sync_config.timeline_limit + < room_config.timeline_limit + ): + rooms_should_send.add(room_id) + continue + + status = previous_connection_state.rooms.have_sent_room(room_id) + if ( + # The room was never sent down before so the client needs to know + # about it regardless of any updates. + status.status == HaveSentRoomFlag.NEVER + # `PREVIOUSLY` literally means the "room was sent down before *AND* + # there are updates we haven't sent down" so we already know this + # room has updates. + or status.status == HaveSentRoomFlag.PREVIOUSLY + ): + rooms_should_send.add(room_id) + elif status.status == HaveSentRoomFlag.LIVE: + # We know that we've sent all updates up until `from_token`, + # so we just need to check if there have been updates since + # then. + pass + else: + assert_never(status.status) + + # We only need to check for new events since any state changes + # will also come down as new events. 
+ rooms_that_have_updates = ( + self.store.get_rooms_that_might_have_updates( + relevant_room_map.keys(), from_token.room_key + ) + ) + rooms_should_send.update(rooms_that_have_updates) + relevant_rooms_to_send_map = { + room_id: room_sync_config + for room_id, room_sync_config in relevant_room_map.items() + if room_id in rooms_should_send + } + + return relevant_rooms_to_send_map + + @trace + async def _get_rewind_changes_to_current_membership_to_token( + self, + user: UserID, + rooms_for_user: Mapping[str, RoomsForUserType], + to_token: StreamToken, + ) -> Mapping[str, Optional[RoomsForUser]]: + """ + Takes the current set of rooms for a user (retrieved after the given + token), and returns the changes needed to "rewind" it to match the set of + memberships *at that token* (<= `to_token`). + + Args: + user: User to fetch rooms for + rooms_for_user: The set of rooms for the user after the `to_token`. + to_token: The token to rewind to + + Returns: + The changes to apply to rewind the the current memberships. + """ + # If the user has never joined any rooms before, we can just return an empty list + if not rooms_for_user: + return {} + + user_id = user.to_string() + + # Get the `RoomStreamToken` that represents the spot we queried up to when we got + # our membership snapshot from `get_rooms_for_local_user_where_membership_is()`. + # + # First, we need to get the max stream_ordering of each event persister instance + # that we queried events from. 
+ instance_to_max_stream_ordering_map: Dict[str, int] = {} + for room_for_user in rooms_for_user.values(): + instance_name = room_for_user.event_pos.instance_name + stream_ordering = room_for_user.event_pos.stream + + current_instance_max_stream_ordering = ( + instance_to_max_stream_ordering_map.get(instance_name) + ) + if ( + current_instance_max_stream_ordering is None + or stream_ordering > current_instance_max_stream_ordering + ): + instance_to_max_stream_ordering_map[instance_name] = stream_ordering + + # Then assemble the `RoomStreamToken` + min_stream_pos = min(instance_to_max_stream_ordering_map.values()) + membership_snapshot_token = RoomStreamToken( + # Minimum position in the `instance_map` + stream=min_stream_pos, + instance_map=immutabledict( + { + instance_name: stream_pos + for instance_name, stream_pos in instance_to_max_stream_ordering_map.items() + if stream_pos > min_stream_pos + } + ), + ) + + # Since we fetched the users room list at some point in time after the + # tokens, we need to revert/rewind some membership changes to match the point in + # time of the `to_token`. 
In particular, we need to make these fixups: + # + # - a) Remove rooms that the user joined after the `to_token` + # - b) Update room membership events to the point in time of the `to_token` + + # Fetch membership changes that fall in the range from `to_token` up to + # `membership_snapshot_token` + # + # If our `to_token` is already the same or ahead of the latest room membership + # for the user, we don't need to do any "2)" fix-ups and can just straight-up + # use the room list from the snapshot as a base (nothing has changed) + current_state_delta_membership_changes_after_to_token = [] + if not membership_snapshot_token.is_before_or_eq(to_token.room_key): + current_state_delta_membership_changes_after_to_token = ( + await self.store.get_current_state_delta_membership_changes_for_user( + user_id, + from_key=to_token.room_key, + to_key=membership_snapshot_token, + excluded_room_ids=self.rooms_to_exclude_globally, + ) + ) + + if not current_state_delta_membership_changes_after_to_token: + # There have been no membership changes, so we can early return. + return {} + + # Otherwise we're about to make changes to `rooms_for_user`, so we turn + # it into a mutable dict. + changes: Dict[str, Optional[RoomsForUser]] = {} + + # Assemble a list of the first membership event after the `to_token` so we can + # step backward to the previous membership that would apply to the from/to + # range. + first_membership_change_by_room_id_after_to_token: Dict[ + str, CurrentStateDeltaMembership + ] = {} + for membership_change in current_state_delta_membership_changes_after_to_token: + # Only set if we haven't already set it + first_membership_change_by_room_id_after_to_token.setdefault( + membership_change.room_id, membership_change + ) + + # Since we fetched a snapshot of the users room list at some point in time after + # the from/to tokens, we need to revert/rewind some membership changes to match + # the point in time of the `to_token`. 
+ for ( + room_id, + first_membership_change_after_to_token, + ) in first_membership_change_by_room_id_after_to_token.items(): + # 1a) Remove rooms that the user joined after the `to_token` + if first_membership_change_after_to_token.prev_event_id is None: + changes[room_id] = None + # 1b) 1c) From the first membership event after the `to_token`, step backward to the + # previous membership that would apply to the from/to range. + else: + # We don't expect these fields to be `None` if we have a `prev_event_id` + # but we're being defensive since it's possible that the prev event was + # culled from the database. + if ( + first_membership_change_after_to_token.prev_event_pos is not None + and first_membership_change_after_to_token.prev_membership + is not None + and first_membership_change_after_to_token.prev_sender is not None + ): + # We need to know the room version ID, which we normally we + # can get from the current membership, but if we don't have + # that then we need to query the DB. + current_membership = rooms_for_user.get(room_id) + if current_membership is not None: + room_version_id = current_membership.room_version_id + else: + room_version_id = await self.store.get_room_version_id(room_id) + + changes[room_id] = RoomsForUser( + room_id=room_id, + event_id=first_membership_change_after_to_token.prev_event_id, + event_pos=first_membership_change_after_to_token.prev_event_pos, + membership=first_membership_change_after_to_token.prev_membership, + sender=first_membership_change_after_to_token.prev_sender, + room_version_id=room_version_id, + ) + else: + # If we can't find the previous membership event, we shouldn't + # include the room in the sync response since we can't determine the + # exact membership state and shouldn't rely on the current snapshot. 
@trace
async def get_room_membership_for_user_at_to_token(
    self,
    user: UserID,
    to_token: StreamToken,
    from_token: Optional[StreamToken],
) -> Tuple[Dict[str, RoomsForUserType], AbstractSet[str], AbstractSet[str]]:
    """
    Fetch room IDs that the user has had membership in (the full room list including
    long-lost left rooms that will be filtered, sorted, and sliced).

    We're looking for rooms where the user has had any sort of membership in the
    token range (> `from_token` and <= `to_token`)

    In order for bans/kicks to not show up, you need to `/forget` those rooms. This
    doesn't modify the event itself though and only adds the `forgotten` flag to the
    `room_memberships` table in Synapse. There isn't a way to tell when a room was
    forgotten at the moment so we can't factor it into the token range.

    Args:
        user: User to fetch rooms for
        to_token: The token to fetch rooms up to.
        from_token: The point in the stream to sync from.

    Returns:
        A 3-tuple of:
          - A dictionary of room IDs that the user has had membership in along with
            membership information in that room at the time of `to_token`.
          - Set of newly joined rooms
          - Set of newly left rooms
    """
    user_id = user.to_string()

    # First grab a current snapshot rooms for the user
    # (also handles forgotten rooms)
    room_for_user_list = await self.store.get_rooms_for_local_user_where_membership_is(
        user_id=user_id,
        # We want to fetch any kind of membership (joined and left rooms) in order
        # to get the `event_pos` of the latest room membership event for the
        # user.
        membership_list=Membership.LIST,
        excluded_rooms=self.rooms_to_exclude_globally,
    )

    # We filter out unknown room versions before we try and load any
    # metadata about the room. They shouldn't go down sync anyway, and their
    # metadata may be in a broken state.
    room_for_user_list = [
        room_for_user
        for room_for_user in room_for_user_list
        if room_for_user.room_version_id in KNOWN_ROOM_VERSIONS
    ]

    # Remove invites from ignored users
    ignored_users = await self.store.ignored_users(user_id)
    if ignored_users:
        room_for_user_list = [
            room_for_user
            for room_for_user in room_for_user_list
            if not (
                room_for_user.membership == Membership.INVITE
                and room_for_user.sender in ignored_users
            )
        ]

    (
        newly_joined_room_ids,
        newly_left_room_map,
    ) = await self._get_newly_joined_and_left_rooms_fallback(
        user_id, to_token=to_token, from_token=from_token
    )

    # If the user has never joined any rooms before, we can just return an empty
    # list. We also have to check the `newly_left_room_map` in case someone was
    # state reset out of all of the rooms they were in.
    if not room_for_user_list and not newly_left_room_map:
        return {}, set(), set()

    # Since we fetched the users room list at some point in time after the
    # tokens, we need to revert/rewind some membership changes to match the point in
    # time of the `to_token`.
    rooms_for_user: Dict[str, RoomsForUserType] = {
        room.room_id: room for room in room_for_user_list
    }
    changes = await self._get_rewind_changes_to_current_membership_to_token(
        user, rooms_for_user, to_token
    )
    for room_id, change_room_for_user in changes.items():
        if change_room_for_user is None:
            # `None` means the room should not appear at `to_token` at all.
            rooms_for_user.pop(room_id, None)
        else:
            rooms_for_user[room_id] = change_room_for_user

    # Ensure we have entries for rooms that the user has been "state reset"
    # out of. These are rooms that appear in the `newly_left_rooms` map but
    # aren't in the `rooms_for_user` map.
    for room_id, newly_left_room_for_user in newly_left_room_map.items():
        # If we already know about the room, it's not a state reset
        if room_id in rooms_for_user:
            continue

        # This should be true if it's a state reset.
        #
        # Memberships are plain strings, so compare by value: an identity
        # (`is`) check only holds while both sides happen to be the very same
        # string object.
        assert newly_left_room_for_user.membership == Membership.LEAVE
        assert newly_left_room_for_user.event_id is None
        assert newly_left_room_for_user.sender is None

        rooms_for_user[room_id] = newly_left_room_for_user

    return rooms_for_user, newly_joined_room_ids, set(newly_left_room_map)
+ """ + + newly_joined_room_ids: Set[str] = set() + newly_left_room_map: Dict[str, RoomsForUserStateReset] = {} + + if not from_token: + return newly_joined_room_ids, newly_left_room_map + + changes = await self.store.get_sliding_sync_membership_changes( + user_id, + from_key=from_token.room_key, + to_key=to_token.room_key, + excluded_room_ids=set(self.rooms_to_exclude_globally), + ) + + for room_id, entry in changes.items(): + if entry.membership == Membership.JOIN: + newly_joined_room_ids.add(room_id) + elif entry.membership == Membership.LEAVE: + newly_left_room_map[room_id] = entry + + return newly_joined_room_ids, newly_left_room_map + + @trace + async def _get_newly_joined_and_left_rooms_fallback( + self, + user_id: str, + to_token: StreamToken, + from_token: Optional[StreamToken], + ) -> Tuple[AbstractSet[str], Mapping[str, RoomsForUserStateReset]]: + """Fetch the sets of rooms that the user newly joined or left in the + given token range. + + Note: there may be rooms in the newly left rooms where the user was + "state reset" out of the room, and so that room would not be part of the + "current memberships" of the user. + + Returns: + A 2-tuple of newly joined room IDs and a map of newly_left room + IDs to the `RoomsForUserStateReset` entry. + + We're using `RoomsForUserStateReset` but that doesn't necessarily mean the + user was state reset of the rooms. It's just that the `event_id`/`sender` + are optional and we can't tell the difference between the server leaving the + room when the user was the last person participating in the room and left or + was state reset out of the room. To actually check for a state reset, you + need to check if a membership still exists in the room. 
+ """ + newly_joined_room_ids: Set[str] = set() + newly_left_room_map: Dict[str, RoomsForUserStateReset] = {} + + # We need to figure out the + # + # - 1) Figure out which rooms are `newly_left` rooms (> `from_token` and <= `to_token`) + # - 2) Figure out which rooms are `newly_joined` (> `from_token` and <= `to_token`) + + # 1) Fetch membership changes that fall in the range from `from_token` up to `to_token` + current_state_delta_membership_changes_in_from_to_range = [] + if from_token: + current_state_delta_membership_changes_in_from_to_range = ( + await self.store.get_current_state_delta_membership_changes_for_user( + user_id, + from_key=from_token.room_key, + to_key=to_token.room_key, + excluded_room_ids=self.rooms_to_exclude_globally, + ) + ) + + # 1) Assemble a list of the last membership events in some given ranges. Someone + # could have left and joined multiple times during the given range but we only + # care about end-result so we grab the last one. + last_membership_change_by_room_id_in_from_to_range: Dict[ + str, CurrentStateDeltaMembership + ] = {} + # We also want to assemble a list of the first membership events during the token + # range so we can step backward to the previous membership that would apply to + # before the token range to see if we have `newly_joined` the room. + first_membership_change_by_room_id_in_from_to_range: Dict[ + str, CurrentStateDeltaMembership + ] = {} + # Keep track if the room has a non-join event in the token range so we can later + # tell if it was a `newly_joined` room. If the last membership event in the + # token range is a join and there is also some non-join in the range, we know + # they `newly_joined`. 
+ has_non_join_event_by_room_id_in_from_to_range: Dict[str, bool] = {} + for ( + membership_change + ) in current_state_delta_membership_changes_in_from_to_range: + room_id = membership_change.room_id + + last_membership_change_by_room_id_in_from_to_range[room_id] = ( + membership_change + ) + # Only set if we haven't already set it + first_membership_change_by_room_id_in_from_to_range.setdefault( + room_id, membership_change + ) + + if membership_change.membership != Membership.JOIN: + has_non_join_event_by_room_id_in_from_to_range[room_id] = True + + # 1) Fixup + # + # 2) We also want to assemble a list of possibly newly joined rooms. Someone + # could have left and joined multiple times during the given range but we only + # care about whether they are joined at the end of the token range so we are + # working with the last membership even in the token range. + possibly_newly_joined_room_ids = set() + for ( + last_membership_change_in_from_to_range + ) in last_membership_change_by_room_id_in_from_to_range.values(): + room_id = last_membership_change_in_from_to_range.room_id + + # 2) + if last_membership_change_in_from_to_range.membership == Membership.JOIN: + possibly_newly_joined_room_ids.add(room_id) + + # 1) Figure out newly_left rooms (> `from_token` and <= `to_token`). 
+ if last_membership_change_in_from_to_range.membership == Membership.LEAVE: + # 1) Mark this room as `newly_left` + newly_left_room_map[room_id] = RoomsForUserStateReset( + room_id=room_id, + sender=last_membership_change_in_from_to_range.sender, + membership=Membership.LEAVE, + event_id=last_membership_change_in_from_to_range.event_id, + event_pos=last_membership_change_in_from_to_range.event_pos, + room_version_id=await self.store.get_room_version_id(room_id), + ) + + # 2) Figure out `newly_joined` + for room_id in possibly_newly_joined_room_ids: + has_non_join_in_from_to_range = ( + has_non_join_event_by_room_id_in_from_to_range.get(room_id, False) + ) + # If the last membership event in the token range is a join and there is + # also some non-join in the range, we know they `newly_joined`. + if has_non_join_in_from_to_range: + # We found a `newly_joined` room (we left and joined within the token range) + newly_joined_room_ids.add(room_id) + else: + prev_event_id = first_membership_change_by_room_id_in_from_to_range[ + room_id + ].prev_event_id + prev_membership = first_membership_change_by_room_id_in_from_to_range[ + room_id + ].prev_membership + + if prev_event_id is None: + # We found a `newly_joined` room (we are joining the room for the + # first time within the token range) + newly_joined_room_ids.add(room_id) + # Last resort, we need to step back to the previous membership event + # just before the token range to see if we're joined then or not. 
@trace
async def _get_dm_rooms_for_user(
    self,
    user_id: str,
) -> AbstractSet[str]:
    """Return the room IDs listed in the user's `m.direct` global account data.

    Global account data is used rather than `is_direct` on membership events
    because that property only appears on the invitee's membership event (it
    doesn't show up for the inviter). No `to_token` can be applied here since
    only the latest account data for the user is tracked.
    """
    direct_account_data = await self.store.get_global_account_data_by_type_for_user(
        user_id, AccountDataTypes.DIRECT
    )

    # Account data is set by the client so it needs to be scrutinized: anything
    # that isn't a dict of lists of strings is silently ignored.
    if not isinstance(direct_account_data, dict):
        return set()

    return {
        candidate_room_id
        for candidate_room_ids in direct_account_data.values()
        if isinstance(candidate_room_ids, list)
        for candidate_room_id in candidate_room_ids
        if isinstance(candidate_room_id, str)
    }
+ + We're looking for rooms where the user has the following state in the token + range (> `from_token` and <= `to_token`): + + - `invite`, `join`, `knock`, `ban` membership events + - Kicks (`leave` membership events where `sender` is different from the + `user_id`/`state_key`) + - `newly_left` (rooms that were left during the given token range) + - In order for bans/kicks to not show up in sync, you need to `/forget` those + rooms. This doesn't modify the event itself though and only adds the + `forgotten` flag to the `room_memberships` table in Synapse. There isn't a way + to tell when a room was forgotten at the moment so we can't factor it into the + from/to range. + + Args: + user: User that is syncing + room_membership_for_user_map: Room membership for the user + newly_left_room_ids: The set of room IDs we have newly left + + Returns: + A dictionary of room IDs that should be listed in the sync response along + with membership information in that room at the time of `to_token`. + """ + user_id = user.to_string() + + # Filter rooms to only what we're interested to sync with + filtered_sync_room_map = { + room_id: room_membership_for_user + for room_id, room_membership_for_user in room_membership_for_user_map.items() + if filter_membership_for_sync( + user_id=user_id, + room_membership_for_user=room_membership_for_user, + newly_left=room_id in newly_left_room_ids, + ) + } + + return filtered_sync_room_map + + async def check_room_subscription_allowed_for_user( + self, + room_id: str, + room_membership_for_user_map: Dict[str, RoomsForUserType], + to_token: StreamToken, + ) -> Optional[RoomsForUserType]: + """ + Check whether the user is allowed to see the room based on whether they have + ever had membership in the room or if the room is `world_readable`. + + Similar to `check_user_in_room_or_world_readable(...)` + + Args: + room_id: Room to check + room_membership_for_user_map: Room membership for the user at the time of + the `to_token` (<= `to_token`). 
+ to_token: The token to fetch rooms up to. + + Returns: + The room membership for the user if they are allowed to subscribe to the + room else `None`. + """ + + # We can first check if they are already allowed to see the room based + # on our previous work to assemble the `room_membership_for_user_map`. + # + # If they have had any membership in the room over time (up to the `to_token`), + # let them subscribe and see what they can. + existing_membership_for_user = room_membership_for_user_map.get(room_id) + if existing_membership_for_user is not None: + return existing_membership_for_user + + # TODO: Handle `world_readable` rooms + return None + + # If the room is `world_readable`, it doesn't matter whether they can join, + # everyone can see the room. + # not_in_room_membership_for_user = _RoomMembershipForUser( + # room_id=room_id, + # event_id=None, + # event_pos=None, + # membership=None, + # sender=None, + # newly_joined=False, + # newly_left=False, + # is_dm=False, + # ) + # room_state = await self.get_current_state_at( + # room_id=room_id, + # room_membership_for_user_at_to_token=not_in_room_membership_for_user, + # state_filter=StateFilter.from_types( + # [(EventTypes.RoomHistoryVisibility, "")] + # ), + # to_token=to_token, + # ) + + # visibility_event = room_state.get((EventTypes.RoomHistoryVisibility, "")) + # if ( + # visibility_event is not None + # and visibility_event.content.get("history_visibility") + # == HistoryVisibility.WORLD_READABLE + # ): + # return not_in_room_membership_for_user + + # return None + + @trace + async def _bulk_get_stripped_state_for_rooms_from_sync_room_map( + self, + room_ids: StrCollection, + sync_room_map: Dict[str, RoomsForUserType], + ) -> Dict[str, Optional[StateMap[StrippedStateEvent]]]: + """ + Fetch stripped state for a list of room IDs. Stripped state is only + applicable to invite/knock rooms. Other rooms will have `None` as their + stripped state. + + For invite rooms, we pull from `unsigned.invite_room_state`. 
+ For knock rooms, we pull from `unsigned.knock_room_state`. + + Args: + room_ids: Room IDs to fetch stripped state for + sync_room_map: Dictionary of room IDs to sort along with membership + information in the room at the time of `to_token`. + + Returns: + Mapping from room_id to mapping of (type, state_key) to stripped state + event. + """ + room_id_to_stripped_state_map: Dict[ + str, Optional[StateMap[StrippedStateEvent]] + ] = {} + + # Fetch what we haven't before + room_ids_to_fetch = [ + room_id + for room_id in room_ids + if room_id not in room_id_to_stripped_state_map + ] + + # Gather a list of event IDs we can grab stripped state from + invite_or_knock_event_ids: List[str] = [] + for room_id in room_ids_to_fetch: + if sync_room_map[room_id].membership in ( + Membership.INVITE, + Membership.KNOCK, + ): + event_id = sync_room_map[room_id].event_id + # If this is an invite/knock then there should be an event_id + assert event_id is not None + invite_or_knock_event_ids.append(event_id) + else: + room_id_to_stripped_state_map[room_id] = None + + invite_or_knock_events = await self.store.get_events(invite_or_knock_event_ids) + for invite_or_knock_event in invite_or_knock_events.values(): + room_id = invite_or_knock_event.room_id + membership = invite_or_knock_event.membership + + raw_stripped_state_events = None + if membership == Membership.INVITE: + invite_room_state = invite_or_knock_event.unsigned.get( + "invite_room_state" + ) + raw_stripped_state_events = invite_room_state + elif membership == Membership.KNOCK: + knock_room_state = invite_or_knock_event.unsigned.get( + "knock_room_state" + ) + raw_stripped_state_events = knock_room_state + else: + raise AssertionError( + f"Unexpected membership {membership} (this is a problem with Synapse itself)" + ) + + stripped_state_map: Optional[MutableStateMap[StrippedStateEvent]] = None + # Scrutinize unsigned things. 
`raw_stripped_state_events` should be a list + # of stripped events + if raw_stripped_state_events is not None: + stripped_state_map = {} + if isinstance(raw_stripped_state_events, list): + for raw_stripped_event in raw_stripped_state_events: + stripped_state_event = parse_stripped_state_event( + raw_stripped_event + ) + if stripped_state_event is not None: + stripped_state_map[ + ( + stripped_state_event.type, + stripped_state_event.state_key, + ) + ] = stripped_state_event + + room_id_to_stripped_state_map[room_id] = stripped_state_map + + return room_id_to_stripped_state_map + + @trace + async def _bulk_get_partial_current_state_content_for_rooms( + self, + content_type: Literal[ + # `content.type` from `EventTypes.Create` + "room_type", + # `content.algorithm` from `EventTypes.RoomEncryption` + "room_encryption", + ], + room_ids: Set[str], + sync_room_map: Dict[str, RoomsForUserType], + to_token: StreamToken, + room_id_to_stripped_state_map: Dict[ + str, Optional[StateMap[StrippedStateEvent]] + ], + ) -> Mapping[str, Union[Optional[str], StateSentinel]]: + """ + Get the given state event content for a list of rooms. First we check the + current state of the room, then fall back to stripped state if available, then + historical state. + + Args: + content_type: Which content to grab + room_ids: Room IDs to fetch the given content field for. + sync_room_map: Dictionary of room IDs to sort along with membership + information in the room at the time of `to_token`. + to_token: We filter based on the state of the room at this token + room_id_to_stripped_state_map: This does not need to be filled in before + calling this function. Mapping from room_id to mapping of (type, state_key) + to stripped state event. Modified in place when we fetch new rooms so we can + save work next time this function is called. + + Returns: + A mapping from room ID to the state event content if the room has + the given state event (event_type, ""), otherwise `None`. 
Rooms unknown to + this server will return `ROOM_UNKNOWN_SENTINEL`. + """ + room_id_to_content: Dict[str, Union[Optional[str], StateSentinel]] = {} + + # As a bulk shortcut, use the current state if the server is participating in the + # room (meaning we have current state). Ideally, for leave/ban rooms, we would + # want the state at the time of the membership instead of current state to not + # leak anything but we consider the create/encryption stripped state events to + # not be a secret given they are often set at the start of the room and they are + # normally handed out on invite/knock. + # + # Be mindful to only use this for non-sensitive details. For example, even + # though the room name/avatar/topic are also stripped state, they seem a lot + # more sensitive to leak the current state value of. + # + # Since this function is cached, we need to make a mutable copy via + # `dict(...)`. + event_type = "" + event_content_field = "" + if content_type == "room_type": + event_type = EventTypes.Create + event_content_field = EventContentFields.ROOM_TYPE + room_id_to_content = dict(await self.store.bulk_get_room_type(room_ids)) + elif content_type == "room_encryption": + event_type = EventTypes.RoomEncryption + event_content_field = EventContentFields.ENCRYPTION_ALGORITHM + room_id_to_content = dict( + await self.store.bulk_get_room_encryption(room_ids) + ) + else: + assert_never(content_type) + + room_ids_with_results = [ + room_id + for room_id, content_field in room_id_to_content.items() + if content_field is not ROOM_UNKNOWN_SENTINEL + ] + + # We might not have current room state for remote invite/knocks if we are + # the first person on our server to see the room. The best we can do is look + # in the optional stripped state from the invite/knock event. 
+ room_ids_without_results = room_ids.difference( + chain( + room_ids_with_results, + [ + room_id + for room_id, stripped_state_map in room_id_to_stripped_state_map.items() + if stripped_state_map is not None + ], + ) + ) + room_id_to_stripped_state_map.update( + await self._bulk_get_stripped_state_for_rooms_from_sync_room_map( + room_ids_without_results, sync_room_map + ) + ) + + # Update our `room_id_to_content` map based on the stripped state + # (applies to invite/knock rooms) + rooms_ids_without_stripped_state: Set[str] = set() + for room_id in room_ids_without_results: + stripped_state_map = room_id_to_stripped_state_map.get( + room_id, Sentinel.UNSET_SENTINEL + ) + assert stripped_state_map is not Sentinel.UNSET_SENTINEL, ( + f"Stripped state left unset for room {room_id}. " + + "Make sure you're calling `_bulk_get_stripped_state_for_rooms_from_sync_room_map(...)` " + + "with that room_id. (this is a problem with Synapse itself)" + ) + + # If there is some stripped state, we assume the remote server passed *all* + # of the potential stripped state events for the room. + if stripped_state_map is not None: + create_stripped_event = stripped_state_map.get((EventTypes.Create, "")) + stripped_event = stripped_state_map.get((event_type, "")) + # Sanity check that we at least have the create event + if create_stripped_event is not None: + if stripped_event is not None: + room_id_to_content[room_id] = stripped_event.content.get( + event_content_field + ) + else: + # Didn't see the state event we're looking for in the stripped + # state so we can assume relevant content field is `None`. + room_id_to_content[room_id] = None + else: + rooms_ids_without_stripped_state.add(room_id) + + # Last resort, we might not have current room state for rooms that the + # server has left (no one local is in the room) but we can look at the + # historical state. + # + # Update our `room_id_to_content` map based on the state at the time of + # the membership event. 
+ for room_id in rooms_ids_without_stripped_state: + # TODO: It would be nice to look this up in a bulk way (N+1 queries) + # + # TODO: `get_state_at(...)` doesn't take into account the "current state". + room_state = await self.storage_controllers.state.get_state_at( + room_id=room_id, + stream_position=to_token.copy_and_replace( + StreamKeyType.ROOM, + sync_room_map[room_id].event_pos.to_room_stream_token(), + ), + state_filter=StateFilter.from_types( + [ + (EventTypes.Create, ""), + (event_type, ""), + ] + ), + # Partially-stated rooms should have all state events except for + # remote membership events so we don't need to wait at all because + # we only want the create event and some non-member event. + await_full_state=False, + ) + # We can use the create event as a canary to tell whether the server has + # seen the room before + create_event = room_state.get((EventTypes.Create, "")) + state_event = room_state.get((event_type, "")) + + if create_event is None: + # Skip for unknown rooms + continue + + if state_event is not None: + room_id_to_content[room_id] = state_event.content.get( + event_content_field + ) + else: + # Didn't see the state event we're looking for in the historical + # state so we can assume relevant content field is `None`. + room_id_to_content[room_id] = None + + return room_id_to_content + + @trace + async def filter_rooms( + self, + user: UserID, + sync_room_map: Dict[str, RoomsForUserType], + previous_connection_state: PerConnectionState, + filters: SlidingSyncConfig.SlidingSyncList.Filters, + to_token: StreamToken, + dm_room_ids: AbstractSet[str], + ) -> Dict[str, RoomsForUserType]: + """ + Filter rooms based on the sync request. + + Args: + user: User to filter rooms for + sync_room_map: Dictionary of room IDs to sort along with membership + information in the room at the time of `to_token`. 
+ previous_connection_state: Per-connection state used to look up which rooms + have previously been sent down this connection (rooms the user was state + reset out of are kept regardless of filters).
+ filters: Filters to apply + to_token: We filter based on the state of the room at this token + dm_room_ids: Set of room IDs that are DMs for the user + + Returns: + A filtered dictionary of room IDs along with membership information in the + room at the time of `to_token`. + """ + user_id = user.to_string() + + room_id_to_stripped_state_map: Dict[ + str, Optional[StateMap[StrippedStateEvent]] + ] = {} + + filtered_room_id_set = set(sync_room_map.keys()) + + # Filter for Direct-Message (DM) rooms + if filters.is_dm is not None: + with start_active_span("filters.is_dm"): + if filters.is_dm: + # Only DM rooms please + filtered_room_id_set = { + room_id + for room_id in filtered_room_id_set + if room_id in dm_room_ids + } + else: + # Only non-DM rooms please + filtered_room_id_set = { + room_id + for room_id in filtered_room_id_set + if room_id not in dm_room_ids + } + + if filters.spaces is not None: + with start_active_span("filters.spaces"): + raise NotImplementedError() + + # Filter for encrypted rooms + if filters.is_encrypted is not None: + with start_active_span("filters.is_encrypted"): + room_id_to_encryption = ( + await self._bulk_get_partial_current_state_content_for_rooms( + content_type="room_encryption", + room_ids=filtered_room_id_set, + to_token=to_token, + sync_room_map=sync_room_map, + room_id_to_stripped_state_map=room_id_to_stripped_state_map, + ) + ) + + # Make a copy so we don't run into an error: `Set changed size during + # iteration`, when we filter out and remove items + for room_id in filtered_room_id_set.copy(): + encryption = room_id_to_encryption.get( + room_id, ROOM_UNKNOWN_SENTINEL + ) + + # Just remove rooms if we can't determine their encryption status + if encryption is ROOM_UNKNOWN_SENTINEL: + filtered_room_id_set.remove(room_id) + continue + + # If we're looking for encrypted rooms, filter out rooms that are not + # encrypted and vice versa + is_encrypted = encryption is not None + if (filters.is_encrypted and not is_encrypted) or 
( + not filters.is_encrypted and is_encrypted + ): + filtered_room_id_set.remove(room_id) + + # Filter for rooms that the user has been invited to + if filters.is_invite is not None: + with start_active_span("filters.is_invite"): + # Make a copy so we don't run into an error: `Set changed size during + # iteration`, when we filter out and remove items + for room_id in filtered_room_id_set.copy(): + room_for_user = sync_room_map[room_id] + # If we're looking for invite rooms, filter out rooms that the user is + # not invited to and vice versa + if ( + filters.is_invite + and room_for_user.membership != Membership.INVITE + ) or ( + not filters.is_invite + and room_for_user.membership == Membership.INVITE + ): + filtered_room_id_set.remove(room_id) + + # Filter by room type (space vs room, etc). A room must match one of the types + # provided in the list. `None` is a valid type for rooms which do not have a + # room type. + if filters.room_types is not None or filters.not_room_types is not None: + with start_active_span("filters.room_types"): + room_id_to_type = ( + await self._bulk_get_partial_current_state_content_for_rooms( + content_type="room_type", + room_ids=filtered_room_id_set, + to_token=to_token, + sync_room_map=sync_room_map, + room_id_to_stripped_state_map=room_id_to_stripped_state_map, + ) + ) + + # Make a copy so we don't run into an error: `Set changed size during + # iteration`, when we filter out and remove items + for room_id in filtered_room_id_set.copy(): + room_type = room_id_to_type.get(room_id, ROOM_UNKNOWN_SENTINEL) + + # Just remove rooms if we can't determine their type + if room_type is ROOM_UNKNOWN_SENTINEL: + filtered_room_id_set.remove(room_id) + continue + + if ( + filters.room_types is not None + and room_type not in filters.room_types + ): + filtered_room_id_set.remove(room_id) + continue + + if ( + filters.not_room_types is not None + and room_type in filters.not_room_types + ): + filtered_room_id_set.remove(room_id) + continue + + 
if filters.room_name_like is not None: + with start_active_span("filters.room_name_like"): + # TODO: The room name is a bit more sensitive to leak than the + # create/encryption event. Maybe we should consider a better way to fetch + # historical state before implementing this. + # + # room_id_to_create_content = await self._bulk_get_partial_current_state_content_for_rooms( + # content_type="room_name", + # room_ids=filtered_room_id_set, + # to_token=to_token, + # sync_room_map=sync_room_map, + # room_id_to_stripped_state_map=room_id_to_stripped_state_map, + # ) + raise NotImplementedError() + + # Filter by room tags according to the users account data + if filters.tags is not None or filters.not_tags is not None: + with start_active_span("filters.tags"): + # Fetch the user tags for their rooms + room_tags = await self.store.get_tags_for_user(user_id) + room_id_to_tag_name_set: Dict[str, Set[str]] = { + room_id: set(tags.keys()) for room_id, tags in room_tags.items() + } + + if filters.tags is not None: + tags_set = set(filters.tags) + filtered_room_id_set = { + room_id + for room_id in filtered_room_id_set + # Remove rooms that don't have one of the tags in the filter + if room_id_to_tag_name_set.get(room_id, set()).intersection( + tags_set + ) + } + + if filters.not_tags is not None: + not_tags_set = set(filters.not_tags) + filtered_room_id_set = { + room_id + for room_id in filtered_room_id_set + # Remove rooms if they have any of the tags in the filter + if not room_id_to_tag_name_set.get(room_id, set()).intersection( + not_tags_set + ) + } + + # Keep rooms if the user has been state reset out of it but we previously sent + # down the connection before. We want to make sure that we send these down to + # the client regardless of filters so they find out about the state reset. 
+ # + # We don't always have access to the state in a room after being state reset if + # no one else locally on the server is participating in the room so we patch + # these back in manually. + state_reset_out_of_room_id_set = { + room_id + for room_id in sync_room_map.keys() + if sync_room_map[room_id].event_id is None + and previous_connection_state.rooms.have_sent_room(room_id).status + != HaveSentRoomFlag.NEVER + } + + # Assemble a new sync room map but only with the `filtered_room_id_set` + # (plus any rooms the user was state reset out of, see above) + return { + room_id: sync_room_map[room_id] + for room_id in filtered_room_id_set | state_reset_out_of_room_id_set + } + + @trace + async def filter_rooms_using_tables( + self, + user_id: str, + sync_room_map: Mapping[str, RoomsForUserSlidingSync], + previous_connection_state: PerConnectionState, + filters: SlidingSyncConfig.SlidingSyncList.Filters, + to_token: StreamToken, + dm_room_ids: AbstractSet[str], + ) -> Dict[str, RoomsForUserSlidingSync]: + """ + Filter rooms based on the sync request. + + Args: + user_id: ID of the user to filter rooms for (used to look up the user's + room tags when filtering by tags) + sync_room_map: Dictionary of room IDs to sort along with membership + information in the room at the time of `to_token`. + previous_connection_state: Per-connection state used to look up which rooms + have previously been sent down this connection (rooms the user was state + reset out of are kept regardless of filters). + filters: Filters to apply + to_token: Not read by the filtering logic in this method; the filters use + the fields already present on the `sync_room_map` entries. + dm_room_ids: Set of room IDs which are DMs + + Returns: + A filtered dictionary of room IDs along with membership information in the + room at the time of `to_token`. 
+ """ + + filtered_room_id_set = set(sync_room_map.keys()) + + # Filter for Direct-Message (DM) rooms + if filters.is_dm is not None: + with start_active_span("filters.is_dm"): + if filters.is_dm: + # Intersect with the DM room set + filtered_room_id_set &= dm_room_ids + else: + # Remove DMs + filtered_room_id_set -= dm_room_ids + + if filters.spaces is not None: + with start_active_span("filters.spaces"): + raise NotImplementedError() + + # Filter for encrypted rooms + if filters.is_encrypted is not None: + filtered_room_id_set = { + room_id + for room_id in filtered_room_id_set + # Remove rooms if we can't figure out what the encryption status is + if sync_room_map[room_id].has_known_state + # Or remove if it doesn't match the filter + and sync_room_map[room_id].is_encrypted == filters.is_encrypted + } + + # Filter for rooms that the user has been invited to + if filters.is_invite is not None: + with start_active_span("filters.is_invite"): + # Make a copy so we don't run into an error: `Set changed size during + # iteration`, when we filter out and remove items + for room_id in filtered_room_id_set.copy(): + room_for_user = sync_room_map[room_id] + # If we're looking for invite rooms, filter out rooms that the user is + # not invited to and vice versa + if ( + filters.is_invite + and room_for_user.membership != Membership.INVITE + ) or ( + not filters.is_invite + and room_for_user.membership == Membership.INVITE + ): + filtered_room_id_set.remove(room_id) + + # Filter by room type (space vs room, etc). A room must match one of the types + # provided in the list. `None` is a valid type for rooms which do not have a + # room type. 
+ if filters.room_types is not None or filters.not_room_types is not None: + with start_active_span("filters.room_types"): + # Make a copy so we don't run into an error: `Set changed size during + # iteration`, when we filter out and remove items + for room_id in filtered_room_id_set.copy(): + # Remove rooms if we can't figure out what room type it is + if not sync_room_map[room_id].has_known_state: + filtered_room_id_set.remove(room_id) + continue + + room_type = sync_room_map[room_id].room_type + + if ( + filters.room_types is not None + and room_type not in filters.room_types + ): + filtered_room_id_set.remove(room_id) + continue + + if ( + filters.not_room_types is not None + and room_type in filters.not_room_types + ): + filtered_room_id_set.remove(room_id) + continue + + if filters.room_name_like is not None: + with start_active_span("filters.room_name_like"): + # TODO: The room name is a bit more sensitive to leak than the + # create/encryption event. Maybe we should consider a better way to fetch + # historical state before implementing this. 
+ # + # room_id_to_create_content = await self._bulk_get_partial_current_state_content_for_rooms( + # content_type="room_name", + # room_ids=filtered_room_id_set, + # to_token=to_token, + # sync_room_map=sync_room_map, + # room_id_to_stripped_state_map=room_id_to_stripped_state_map, + # ) + raise NotImplementedError() + + # Filter by room tags according to the users account data + if filters.tags is not None or filters.not_tags is not None: + with start_active_span("filters.tags"): + # Fetch the user tags for their rooms + room_tags = await self.store.get_tags_for_user(user_id) + room_id_to_tag_name_set: Dict[str, Set[str]] = { + room_id: set(tags.keys()) for room_id, tags in room_tags.items() + } + + if filters.tags is not None: + tags_set = set(filters.tags) + filtered_room_id_set = { + room_id + for room_id in filtered_room_id_set + # Remove rooms that don't have one of the tags in the filter + if room_id_to_tag_name_set.get(room_id, set()).intersection( + tags_set + ) + } + + if filters.not_tags is not None: + not_tags_set = set(filters.not_tags) + filtered_room_id_set = { + room_id + for room_id in filtered_room_id_set + # Remove rooms if they have any of the tags in the filter + if not room_id_to_tag_name_set.get(room_id, set()).intersection( + not_tags_set + ) + } + + # Keep rooms if the user has been state reset out of it but we previously sent + # down the connection before. We want to make sure that we send these down to + # the client regardless of filters so they find out about the state reset. + # + # We don't always have access to the state in a room after being state reset if + # no one else locally on the server is participating in the room so we patch + # these back in manually. 
+ state_reset_out_of_room_id_set = { + room_id + for room_id in sync_room_map.keys() + if sync_room_map[room_id].event_id is None + and previous_connection_state.rooms.have_sent_room(room_id).status + != HaveSentRoomFlag.NEVER + } + + # Assemble a new sync room map but only with the `filtered_room_id_set` + return { + room_id: sync_room_map[room_id] + for room_id in filtered_room_id_set | state_reset_out_of_room_id_set + } + + @trace + async def sort_rooms( + self, + sync_room_map: Dict[str, RoomsForUserType], + to_token: StreamToken, + limit: Optional[int] = None, + ) -> List[RoomsForUserType]: + """ + Sort by `stream_ordering` of the last event that the user should see in the + room. `stream_ordering` is unique so we get a stable sort. + + If `limit` is specified then sort may return fewer entries, but will + always return at least the top N rooms. This is useful as we don't always + need to sort the full list, but are just interested in the top N. + + Args: + sync_room_map: Dictionary of room IDs to sort along with membership + information in the room at the time of `to_token`. + to_token: We sort based on the events in the room at this token (<= `to_token`) + limit: The number of rooms that we need to return from the top of the list. + + Returns: + A sorted list of room IDs by `stream_ordering` along with membership information. + """ + + # Assemble a map of room ID to the `stream_ordering` of the last activity that the + # user should see in the room (<= `to_token`) + last_activity_in_room_map: Dict[str, int] = {} + + # Same as above, except for positions that we know are in the event + # stream cache. + cached_positions: Dict[str, int] = {} + + earliest_cache_position = ( + self.store._events_stream_cache.get_earliest_known_position() + ) + + for room_id, room_for_user in sync_room_map.items(): + if room_for_user.membership == Membership.JOIN: + # For joined rooms check the stream change cache. 
+ cached_position = ( + self.store._events_stream_cache.get_max_pos_of_last_change(room_id) + ) + if cached_position is not None: + cached_positions[room_id] = cached_position + else: + # If the user has left/been invited/knocked/been banned from a + # room, they shouldn't see anything past that point. + # + # FIXME: It's possible that people should see beyond this point + # in invited/knocked cases if for example the room has + # `invite`/`world_readable` history visibility, see + # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932 + last_activity_in_room_map[room_id] = room_for_user.event_pos.stream + + # If the stream position is in range of the stream change cache + # we can include it. + if room_for_user.event_pos.stream > earliest_cache_position: + cached_positions[room_id] = room_for_user.event_pos.stream + + # If we are only asked for the top N rooms, and we have enough from + # looking in the stream change cache, then we can return early. This + # is because the cache must include all entries above + # `.get_earliest_known_position()`. + if limit is not None and len(cached_positions) >= limit: + # ... but first we need to handle the case where the cached max + # position is greater than the to_token, in which case we do + # actually query the DB. This should happen rarely, so can do it in + # a loop. + for room_id, position in list(cached_positions.items()): + if position > to_token.room_key.stream: + result = await self.store.get_last_event_pos_in_room_before_stream_ordering( + room_id, to_token.room_key + ) + if ( + result is not None + and result[1].stream > earliest_cache_position + ): + # We have a stream position in the cached range. + cached_positions[room_id] = result[1].stream + else: + # No position in the range, so we remove the entry. 
+ cached_positions.pop(room_id) + + if limit is not None and len(cached_positions) >= limit: + return sorted( + ( + room + for room in sync_room_map.values() + if room.room_id in cached_positions + ), + # Sort by the last activity (stream_ordering) in the room + key=lambda room_info: cached_positions[room_info.room_id], + # We want descending order + reverse=True, + ) + + # For fully-joined rooms, we find the latest activity at/before the + # `to_token`. + joined_room_positions = ( + await self.store.bulk_get_last_event_pos_in_room_before_stream_ordering( + [ + room_id + for room_id, room_for_user in sync_room_map.items() + if room_for_user.membership == Membership.JOIN + ], + to_token.room_key, + ) + ) + + last_activity_in_room_map.update(joined_room_positions) + + return sorted( + sync_room_map.values(), + # Sort by the last activity (stream_ordering) in the room + key=lambda room_info: last_activity_in_room_map[room_info.room_id], + # We want descending order + reverse=True, + ) + + async def get_is_encrypted_for_room_at_token( + self, room_id: str, to_token: RoomStreamToken + ) -> bool: + """Get if the room is encrypted at the time.""" + + # Fetch the current encryption state + state_ids = await self.store.get_partial_filtered_current_state_ids( + room_id, StateFilter.from_types([(EventTypes.RoomEncryption, "")]) + ) + encryption_event_id = state_ids.get((EventTypes.RoomEncryption, "")) + + # Now roll back the state by looking at the state deltas between + # to_token and now. + deltas = await self.store.get_current_state_deltas_for_room( + room_id, + from_token=to_token, + to_token=self.store.get_room_max_token(), + ) + + for delta in deltas: + if delta.event_type != EventTypes.RoomEncryption: + continue + + # Found the first change, we look at the previous event ID to get + # the state at the to token. 
+ + if delta.prev_event_id is None: + # There is no prev event, so no encryption state event, so room is not encrypted + return False + + encryption_event_id = delta.prev_event_id + break + + # We didn't find an encryption state, room isn't encrypted + if encryption_event_id is None: + return False + + # We found encryption state, check if content has a non-null algorithm + encrypted_event = await self.store.get_event(encryption_event_id) + algorithm = encrypted_event.content.get(EventContentFields.ENCRYPTION_ALGORITHM) + + return algorithm is not None diff --git a/synapse/handlers/sliding_sync/store.py b/synapse/handlers/sliding_sync/store.py new file mode 100644
index 0000000000..d24fccf76f --- /dev/null +++ b/synapse/handlers/sliding_sync/store.py
@@ -0,0 +1,128 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2023 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# + +import logging +from typing import TYPE_CHECKING, Optional + +import attr + +from synapse.logging.opentracing import trace +from synapse.storage.databases.main import DataStore +from synapse.types import SlidingSyncStreamToken +from synapse.types.handlers.sliding_sync import ( + MutablePerConnectionState, + PerConnectionState, + SlidingSyncConfig, +) + +if TYPE_CHECKING: + pass + +logger = logging.getLogger(__name__) + + +@attr.s(auto_attribs=True) +class SlidingSyncConnectionStore: + """Store of per-connection state, including what rooms we have + previously sent down a sliding sync connection. State is fetched and + persisted through `self.store` rather than being held on this object. + + NOTE(review): the worker-safety caveat below talks about connection positions + differing between workers; confirm whether it still applies given that this + class delegates to the database-backed store. + + Note: This is NOT safe to run in a worker setup because connection positions will + point to different sets of rooms on different workers. e.g. for the same connection, + a connection position of 5 might have totally different states on worker A and + worker B. + + One complication that we need to deal with here is needing to handle requests being + resent, i.e. if we sent down a room in a response that the client received, we must + consider the room *not* sent when we get the request again. 
+ + This is handled by using an integer "token", which is returned to the client + as part of the sync token. For each connection we store a mapping from + tokens to the room states, and create a new entry when we send down new + rooms. 
+ + Note that for any given sliding sync connection we will only store a maximum + of two different tokens: the previous token from the request and a new token + sent in the response. When we receive a request with a given token, we then + clear out all other entries with a different token. + + Attributes: + store: The `DataStore` that per-connection state is fetched from and + persisted to. + """ + + store: "DataStore" + + async def get_and_clear_connection_positions( + self, + sync_config: SlidingSyncConfig, + from_token: Optional[SlidingSyncStreamToken], + ) -> PerConnectionState: + """Fetch the per-connection state for the token. + + Args: + sync_config: Supplies the user, the requester's device ID and the + connection ID (an unset `conn_id` is treated as ""). + from_token: The token from the request, if any. + + Returns: + An empty `PerConnectionState` if `from_token` is `None` or its connection + position is 0, otherwise the state returned by the store. + + Raises: + SlidingSyncUnknownPosition if the connection_token is unknown + """ + # If this is our first request, there is no previous connection state to fetch out of the database + if from_token is None or from_token.connection_position == 0: + return PerConnectionState() + + conn_id = sync_config.conn_id or "" + + device_id = sync_config.requester.device_id + assert device_id is not None + + return await self.store.get_and_clear_connection_positions( + sync_config.user.to_string(), + device_id, + conn_id, + from_token.connection_position, + ) + + @trace + async def record_new_state( + self, + sync_config: SlidingSyncConfig, + from_token: Optional[SlidingSyncStreamToken], + new_connection_state: MutablePerConnectionState, + ) -> int: + """Record updated per-connection state, returning the connection + position associated with the new state. + If there are no changes to the state this may return the same token as + the existing per-connection state. 
+ + Args: + sync_config: Supplies the user, the requester's device ID and the + connection ID (an unset `conn_id` is treated as ""). + from_token: The token from the request, if any. + new_connection_state: The state to persist; nothing is persisted when it + reports no updates. + + Returns: + The connection position for the new state. When there are no updates this + is the position from `from_token`, or 0 if there is no `from_token`. + """ + if not new_connection_state.has_updates(): + if from_token is not None: + return from_token.connection_position + else: + return 0 + + # A from token with a zero connection position means there was no + # previously stored connection state, so we treat a zero the same as + # there being no previous position. 
+ previous_connection_position = None + if from_token is not None and from_token.connection_position != 0: + previous_connection_position = from_token.connection_position + + conn_id = sync_config.conn_id or "" + + device_id = sync_config.requester.device_id + assert device_id is not None + + return await self.store.persist_per_connection_state( + sync_config.user.to_string(), + device_id, + conn_id, + previous_connection_position, + new_connection_state, + ) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index ee74289b6c..2795b282e5 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py
@@ -33,17 +33,17 @@ from typing import ( Mapping, NoReturn, Optional, + Protocol, Set, ) from urllib.parse import urlencode import attr -from typing_extensions import Protocol from twisted.web.iweb import IRequest from twisted.web.server import Request -from synapse.api.constants import LoginType +from synapse.api.constants import LoginType, ProfileFields from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.config.sso import SsoAttributeRequirement from synapse.handlers.device import DeviceHandler @@ -81,8 +81,7 @@ class SsoIdentityProvider(Protocol): An Identity Provider, or IdP, is an external HTTP service which authenticates a user to say whether they should be allowed to log in, or perform a given action. - Synapse supports various implementations of IdPs, including OpenID Connect, SAML, - and CAS. + Synapse supports various implementations of IdPs, including OpenID Connect. The main entry point is `handle_redirect_request`, which should return a URI to redirect the user's browser to the IdP's authentication page. @@ -97,7 +96,7 @@ class SsoIdentityProvider(Protocol): def idp_id(self) -> str: """A unique identifier for this SSO provider - Eg, "saml", "cas", "github" + Eg. "github" """ @property @@ -157,7 +156,7 @@ class UserAttributes: class UsernameMappingSession: """Data we track about SSO sessions""" - # A unique identifier for this SSO provider, e.g. "oidc" or "saml". + # A unique identifier for this SSO provider, e.g. "oidc". auth_provider_id: str # An optional session ID from the IdP. @@ -351,7 +350,7 @@ class SsoHandler: Args: auth_provider_id: A unique identifier for this SSO provider, e.g. - "oidc" or "saml". + "oidc". remote_user_id: The user ID according to the remote IdP. This might be an e-mail address, a GUID, or some other form. It must be unique and immutable. @@ -418,7 +417,7 @@ class SsoHandler: Args: auth_provider_id: A unique identifier for this SSO provider, e.g. - "oidc" or "saml". + "oidc". 
remote_user_id: The unique identifier from the SSO provider. @@ -634,7 +633,7 @@ class SsoHandler: Args: auth_provider_id: A unique identifier for this SSO provider, e.g. - "oidc" or "saml". + "oidc". remote_user_id: The unique identifier from the SSO provider. @@ -704,7 +703,7 @@ class SsoHandler: including a non-empty localpart. auth_provider_id: A unique identifier for this SSO provider, e.g. - "oidc" or "saml". + "oidc". remote_user_id: The unique identifier from the SSO provider. @@ -813,9 +812,10 @@ class SsoHandler: # bail if user already has the same avatar profile = await self._profile_handler.get_profile(user_id) - if profile["avatar_url"] is not None: - server_name = profile["avatar_url"].split("/")[-2] - media_id = profile["avatar_url"].split("/")[-1] + if ProfileFields.AVATAR_URL in profile: + avatar_url_parts = profile[ProfileFields.AVATAR_URL].split("/") + server_name = avatar_url_parts[-2] + media_id = avatar_url_parts[-1] if self._is_mine_server_name(server_name): media = await self._media_repo.store.get_local_media(media_id) # type: ignore[has-type] if media is not None and upload_name == media.upload_name: @@ -855,12 +855,12 @@ class SsoHandler: Given an SSO ID, retrieve the user ID for it and complete UIA. Note that this requires that the user is mapped in the "user_external_ids" - table. This will be the case if they have ever logged in via SAML or OIDC in + table. This will be the case if they have ever logged in via OIDC in recentish synapse versions, but may not be for older users. Args: auth_provider_id: A unique identifier for this SSO provider, e.g. - "oidc" or "saml". + "oidc". remote_user_id: The unique identifier from the SSO provider. ui_auth_session_id: The ID of the user-interactive auth session. request: The request to complete. @@ -1184,16 +1184,16 @@ class SsoHandler: Args: auth_provider_id: A unique identifier for this SSO provider, e.g. - "oidc" or "saml". + "oidc". 
auth_provider_session_id: The session ID from the provider to logout expected_user_id: The user we're expecting to logout. If set, it will ignore sessions belonging to other users and log an error. """ # It is expected that this is the main process. - assert isinstance( - self._device_handler, DeviceHandler - ), "revoking SSO sessions can only be called on the main process" + assert isinstance(self._device_handler, DeviceHandler), ( + "revoking SSO sessions can only be called on the main process" + ) # Invalidate any running user-mapping sessions to_delete = [] @@ -1276,12 +1276,16 @@ def _check_attribute_requirement( return False # If the requirement is None, the attribute existing is enough. - if req.value is None: + if req.value is None and req.one_of is None: return True values = attributes[req.attribute] if req.value in values: return True + if req.one_of: + for value in req.one_of: + if value in values: + return True logger.info( "SSO attribute %s did not match required value '%s' (was '%s')", diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 6af2eeb75f..c6f2c38d8d 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py
@@ -66,6 +66,7 @@ from synapse.logging.opentracing import ( from synapse.storage.databases.main.event_push_actions import RoomNotifCounts from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary from synapse.storage.databases.main.stream import PaginateFunction +from synapse.storage.invite_rule import InviteRule from synapse.storage.roommember import MemberSummary from synapse.types import ( DeviceListUpdates, @@ -86,7 +87,7 @@ from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.lrucache import LruCache from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext -from synapse.util.metrics import Measure, measure_func +from synapse.util.metrics import Measure from synapse.visibility import filter_events_for_client if TYPE_CHECKING: @@ -143,6 +144,7 @@ class SyncConfig: filter_collection: FilterCollection is_guest: bool device_id: Optional[str] + use_state_after: bool @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -183,10 +185,7 @@ class JoinedSyncResult: to tell if room needs to be part of the sync result. """ return bool( - self.timeline - or self.state - or self.ephemeral - or self.account_data + self.timeline or self.state or self.ephemeral or self.account_data # nb the notification count does not, er, count: if there's nothing # else in the result, we don't need to send it. ) @@ -575,10 +574,10 @@ class SyncHandler: if timeout == 0 or since_token is None or full_state: # we are going to return immediately, so don't bother calling # notifier.wait_for_events. 
- result: Union[SyncResult, E2eeSyncResult] = ( - await self.current_sync_for_user( - sync_config, sync_version, since_token, full_state=full_state - ) + result: Union[ + SyncResult, E2eeSyncResult + ] = await self.current_sync_for_user( + sync_config, sync_version, since_token, full_state=full_state ) else: # Otherwise, we wait for something to happen and report it to the user. @@ -673,10 +672,10 @@ class SyncHandler: # Go through the `/sync` v2 path if sync_version == SyncVersion.SYNC_V2: - sync_result: Union[SyncResult, E2eeSyncResult] = ( - await self.generate_sync_result( - sync_config, since_token, full_state - ) + sync_result: Union[ + SyncResult, E2eeSyncResult + ] = await self.generate_sync_result( + sync_config, since_token, full_state ) # Go through the MSC3575 Sliding Sync `/sync/e2ee` path elif sync_version == SyncVersion.E2EE_SYNC: @@ -909,7 +908,7 @@ class SyncHandler: # Use `stream_ordering` for updates else paginate_room_events_by_stream_ordering ) - events, end_key = await pagination_method( + events, end_key, limited = await pagination_method( room_id=room_id, # The bounds are reversed so we can paginate backwards # (from newer to older events) starting at to_bound. @@ -917,9 +916,7 @@ class SyncHandler: from_key=end_key, to_key=since_key, direction=Direction.BACKWARDS, - # We add one so we can determine if there are enough events to saturate - # the limit or not (see `limited`) - limit=load_limit + 1, + limit=load_limit, ) # We want to return the events in ascending order (the last event is the # most recent). 
@@ -974,9 +971,6 @@ class SyncHandler: loaded_recents.extend(recents) recents = loaded_recents - if len(events) <= load_limit: - limited = False - break max_repeat -= 1 if len(recents) > timeline_limit: @@ -1149,6 +1143,7 @@ class SyncHandler: since_token: Optional[StreamToken], end_token: StreamToken, full_state: bool, + joined: bool, ) -> MutableStateMap[EventBase]: """Works out the difference in state between the end of the previous sync and the start of the timeline. @@ -1163,6 +1158,7 @@ class SyncHandler: the point just after their leave event. full_state: Whether to force returning the full state. `lazy_load_members` still applies when `full_state` is `True`. + joined: whether the user is currently joined to the room Returns: The state to return in the sync response for the room. @@ -1238,11 +1234,12 @@ class SyncHandler: if full_state: state_ids = await self._compute_state_delta_for_full_sync( room_id, - sync_config.user, + sync_config, batch, end_token, members_to_fetch, timeline_state, + joined, ) else: # If this is an initial sync then full_state should be set, and @@ -1252,6 +1249,7 @@ class SyncHandler: state_ids = await self._compute_state_delta_for_incremental_sync( room_id, + sync_config, batch, since_token, end_token, @@ -1324,20 +1322,24 @@ class SyncHandler: async def _compute_state_delta_for_full_sync( self, room_id: str, - syncing_user: UserID, + sync_config: SyncConfig, batch: TimelineBatch, end_token: StreamToken, members_to_fetch: Optional[Set[str]], timeline_state: StateMap[str], + joined: bool, ) -> StateMap[str]: """Calculate the state events to be included in a full sync response. As with `_compute_state_delta_for_incremental_sync`, the result will include the membership events for the senders of each event in `members_to_fetch`. + Note that whether this returns the state at the start or the end of the + batch depends on `sync_config.use_state_after` (c.f. MSC4222). + Args: room_id: The room we are calculating for. 
- syncing_user: The user that is calling `/sync`. + sync_config: The sync configuration of the user that is calling `/sync`. batch: The timeline batch for the room that will be sent to the user. end_token: Token of the end of the current batch. Normally this will be the same as the global "now_token", but if the user has left the room, @@ -1346,10 +1348,11 @@ class SyncHandler: events in the timeline. timeline_state: The contribution to the room state from state events in `batch`. Only contains the last event for any given state key. + joined: whether the user is currently joined to the room Returns: A map from (type, state_key) to event_id, for each event that we believe - should be included in the `state` part of the sync response. + should be included in the `state` or `state_after` part of the sync response. """ if members_to_fetch is not None: # Lazy-loading of membership events is enabled. @@ -1367,7 +1370,7 @@ class SyncHandler: # is no guarantee that our membership will be in the auth events of # timeline events when the room is partial stated. state_filter = StateFilter.from_lazy_load_member_list( - members_to_fetch.union((syncing_user.to_string(),)) + members_to_fetch.union((sync_config.user.to_string(),)) ) # We are happy to use partial state to compute the `/sync` response. @@ -1381,6 +1384,61 @@ class SyncHandler: await_full_state = True lazy_load_members = False + # Check if we are wanting to return the state at the start or end of the + # timeline. If at the end we can just use the current state. + if sync_config.use_state_after: + # If we're getting the state at the end of the timeline, we can just + # use the current state of the room (and roll back any changes + # between when we fetched the current state and `end_token`). + # + # For rooms we're not joined to, there might be a very large number + # of deltas between `end_token` and "now", and so instead we fetch + # the state at the end of the timeline. 
+ if joined: + state_ids = await self._state_storage_controller.get_current_state_ids( + room_id, + state_filter=state_filter, + await_full_state=await_full_state, + ) + + # Now roll back the state by looking at the state deltas between + # end_token and now. + deltas = await self.store.get_current_state_deltas_for_room( + room_id, + from_token=end_token.room_key, + to_token=self.store.get_room_max_token(), + ) + if deltas: + mutable_state_ids = dict(state_ids) + + # We iterate over the deltas backwards so that if there are + # multiple changes of the same type/state_key we'll + # correctly pick the earliest delta. + for delta in reversed(deltas): + if delta.prev_event_id: + mutable_state_ids[(delta.event_type, delta.state_key)] = ( + delta.prev_event_id + ) + elif (delta.event_type, delta.state_key) in mutable_state_ids: + mutable_state_ids.pop((delta.event_type, delta.state_key)) + + state_ids = mutable_state_ids + + return state_ids + + else: + # Just use state groups to get the state at the end of the + # timeline, i.e. the state at the leave/etc event. + state_at_timeline_end = ( + await self._state_storage_controller.get_state_ids_at( + room_id, + stream_position=end_token, + state_filter=state_filter, + await_full_state=await_full_state, + ) + ) + return state_at_timeline_end + state_at_timeline_end = await self._state_storage_controller.get_state_ids_at( room_id, stream_position=end_token, @@ -1413,6 +1471,7 @@ class SyncHandler: async def _compute_state_delta_for_incremental_sync( self, room_id: str, + sync_config: SyncConfig, batch: TimelineBatch, since_token: StreamToken, end_token: StreamToken, @@ -1427,8 +1486,12 @@ class SyncHandler: (`compute_state_delta`) is responsible for keeping track of which membership events we have already sent to the client, and hence ripping them out. + Note that whether this returns the state at the start or the end of the + batch depends on `sync_config.use_state_after` (c.f. MSC4222). 
+ Args: room_id: The room we are calculating for. + sync_config batch: The timeline batch for the room that will be sent to the user. since_token: Token of the end of the previous batch. end_token: Token of the end of the current batch. Normally this will be @@ -1441,7 +1504,7 @@ class SyncHandler: Returns: A map from (type, state_key) to event_id, for each event that we believe - should be included in the `state` part of the sync response. + should be included in the `state` or `state_after` part of the sync response. """ if members_to_fetch is not None: # Lazy-loading is enabled. Only return the state that is needed. @@ -1453,6 +1516,51 @@ class SyncHandler: await_full_state = True lazy_load_members = False + # Check if we are wanting to return the state at the start or end of the + # timeline. If at the end we can just use the current state delta stream. + if sync_config.use_state_after: + delta_state_ids: MutableStateMap[str] = {} + + if members_to_fetch: + # We're lazy-loading, so the client might need some more member + # events to understand the events in this timeline. So we always + # fish out all the member events corresponding to the timeline + # here. The caller will then dedupe any redundant ones. + member_ids = await self._state_storage_controller.get_current_state_ids( + room_id=room_id, + state_filter=StateFilter.from_types( + (EventTypes.Member, member) for member in members_to_fetch + ), + await_full_state=await_full_state, + ) + delta_state_ids.update(member_ids) + + # We don't do LL filtering for incremental syncs - see + # https://github.com/vector-im/riot-web/issues/7211#issuecomment-419976346 + # N.B. this slows down incr syncs as we are now processing way more + # state in the server than if we were LLing. + # + # i.e. we return all state deltas, including membership changes that + # we'd normally exclude due to LL. 
+ deltas = await self.store.get_current_state_deltas_for_room( + room_id=room_id, + from_token=since_token.room_key, + to_token=end_token.room_key, + ) + for delta in deltas: + if delta.event_id is None: + # There was a state reset and this state entry is no longer + # present, but we have no way of informing the client about + # this, so we just skip it for now. + continue + + # Note that deltas are in stream ordering, so if there are + # multiple deltas for a given type/state_key we'll always pick + # the latest one. + delta_state_ids[(delta.event_type, delta.state_key)] = delta.event_id + + return delta_state_ids + # For a non-gappy sync if the events in the timeline are simply a linear # chain (i.e. no merging/branching of the graph), then we know the state # delta between the end of the previous sync and start of the new one is @@ -1488,13 +1596,16 @@ class SyncHandler: # timeline here. The caller will then dedupe any redundant # ones. - state_ids = await self._state_storage_controller.get_state_ids_for_event( - batch.events[0].event_id, - # we only want members! - state_filter=StateFilter.from_types( - (EventTypes.Member, member) for member in members_to_fetch - ), - await_full_state=False, + state_ids = ( + await self._state_storage_controller.get_state_ids_for_event( + batch.events[0].event_id, + # we only want members! + state_filter=StateFilter.from_types( + (EventTypes.Member, member) + for member in members_to_fetch + ), + await_full_state=False, + ) ) return state_ids @@ -1779,8 +1890,15 @@ class SyncHandler: ) if include_device_list_updates: - device_lists = await self._generate_sync_entry_for_device_list( - sync_result_builder, + # include_device_list_updates can only be True if we have a + # since token. 
+ assert since_token is not None + + device_lists = await self._device_handler.generate_sync_entry_for_device_list( + user_id=user_id, + since_token=since_token, + now_token=sync_result_builder.now_token, + joined_room_ids=sync_result_builder.joined_room_ids, newly_joined_rooms=newly_joined_rooms, newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users, newly_left_rooms=newly_left_rooms, @@ -1892,8 +2010,14 @@ class SyncHandler: newly_left_users, ) = sync_result_builder.calculate_user_changes() - device_lists = await self._generate_sync_entry_for_device_list( - sync_result_builder, + # include_device_list_updates can only be True if we have a + # since token. + assert since_token is not None + device_lists = await self._device_handler.generate_sync_entry_for_device_list( + user_id=user_id, + since_token=since_token, + now_token=sync_result_builder.now_token, + joined_room_ids=sync_result_builder.joined_room_ids, newly_joined_rooms=newly_joined_rooms, newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users, newly_left_rooms=newly_left_rooms, @@ -2070,94 +2194,6 @@ class SyncHandler: return sync_result_builder - @measure_func("_generate_sync_entry_for_device_list") - async def _generate_sync_entry_for_device_list( - self, - sync_result_builder: "SyncResultBuilder", - newly_joined_rooms: AbstractSet[str], - newly_joined_or_invited_or_knocked_users: AbstractSet[str], - newly_left_rooms: AbstractSet[str], - newly_left_users: AbstractSet[str], - ) -> DeviceListUpdates: - """Generate the DeviceListUpdates section of sync - - Args: - sync_result_builder - newly_joined_rooms: Set of rooms user has joined since previous sync - newly_joined_or_invited_or_knocked_users: Set of users that have joined, - been invited to a room or are knocking on a room since - previous sync. 
- newly_left_rooms: Set of rooms user has left since previous sync - newly_left_users: Set of users that have left a room we're in since - previous sync - """ - - user_id = sync_result_builder.sync_config.user.to_string() - since_token = sync_result_builder.since_token - assert since_token is not None - - # Take a copy since these fields will be mutated later. - newly_joined_or_invited_or_knocked_users = set( - newly_joined_or_invited_or_knocked_users - ) - newly_left_users = set(newly_left_users) - - # We want to figure out what user IDs the client should refetch - # device keys for, and which users we aren't going to track changes - # for anymore. - # - # For the first step we check: - # a. if any users we share a room with have updated their devices, - # and - # b. we also check if we've joined any new rooms, or if a user has - # joined a room we're in. - # - # For the second step we just find any users we no longer share a - # room with by looking at all users that have left a room plus users - # that were in a room we've left. - - users_that_have_changed = set() - - joined_room_ids = sync_result_builder.joined_room_ids - - # Step 1a, check for changes in devices of users we share a room - # with - users_that_have_changed = ( - await self._device_handler.get_device_changes_in_shared_rooms( - user_id, - joined_room_ids, - from_token=since_token, - now_token=sync_result_builder.now_token, - ) - ) - - # Step 1b, check for newly joined rooms - for room_id in newly_joined_rooms: - joined_users = await self.store.get_users_in_room(room_id) - newly_joined_or_invited_or_knocked_users.update(joined_users) - - # TODO: Check that these users are actually new, i.e. either they - # weren't in the previous sync *or* they left and rejoined. 
- users_that_have_changed.update(newly_joined_or_invited_or_knocked_users) - - user_signatures_changed = await self.store.get_users_whose_signatures_changed( - user_id, since_token.device_list_key - ) - users_that_have_changed.update(user_signatures_changed) - - # Now find users that we no longer track - for room_id in newly_left_rooms: - left_users = await self.store.get_users_in_room(room_id) - newly_left_users.update(left_users) - - # Remove any users that we still share a room with. - left_users_rooms = await self.store.get_rooms_for_users(newly_left_users) - for user_id, entries in left_users_rooms.items(): - if any(rid in joined_room_ids for rid in entries): - newly_left_users.discard(user_id) - - return DeviceListUpdates(changed=users_that_have_changed, left=newly_left_users) - @trace async def _generate_sync_entry_for_to_device( self, sync_result_builder: "SyncResultBuilder" @@ -2241,18 +2277,18 @@ class SyncHandler: if push_rules_changed: global_account_data = dict(global_account_data) - global_account_data[AccountDataTypes.PUSH_RULES] = ( - await self._push_rules_handler.push_rules_for_user(sync_config.user) - ) + global_account_data[ + AccountDataTypes.PUSH_RULES + ] = await self._push_rules_handler.push_rules_for_user(sync_config.user) else: all_global_account_data = await self.store.get_global_account_data_for_user( user_id ) global_account_data = dict(all_global_account_data) - global_account_data[AccountDataTypes.PUSH_RULES] = ( - await self._push_rules_handler.push_rules_for_user(sync_config.user) - ) + global_account_data[ + AccountDataTypes.PUSH_RULES + ] = await self._push_rules_handler.push_rules_for_user(sync_config.user) account_data_for_user = ( await sync_config.filter_collection.filter_global_account_data( @@ -2514,6 +2550,7 @@ class SyncHandler: room_entries: List[RoomSyncResultBuilder] = [] invited: List[InvitedSyncResult] = [] knocked: List[KnockedSyncResult] = [] + invite_config = await self.store.get_invite_config_for_user(user_id) for 
room_id, events in mem_change_events_by_room_id.items(): # The body of this loop will add this room to at least one of the five lists # above. Things get messy if you've e.g. joined, left, joined then left the @@ -2596,7 +2633,11 @@ class SyncHandler: # Only bother if we're still currently invited should_invite = last_non_join.membership == Membership.INVITE if should_invite: - if last_non_join.sender not in ignored_users: + if ( + last_non_join.sender not in ignored_users + and invite_config.get_invite_rule(last_non_join.sender) + != InviteRule.IGNORE + ): invite_room_sync = InvitedSyncResult(room_id, invite=last_non_join) if invite_room_sync: invited.append(invite_room_sync) @@ -2683,7 +2724,7 @@ class SyncHandler: newly_joined = room_id in newly_joined_rooms if room_entry: - events, start_key = room_entry + events, start_key, _ = room_entry # We want to return the events in ascending order (the last event is the # most recent). events.reverse() @@ -2751,6 +2792,7 @@ class SyncHandler: membership_list=Membership.LIST, excluded_rooms=sync_result_builder.excluded_room_ids, ) + invite_config = await self.store.get_invite_config_for_user(user_id) room_entries = [] invited = [] @@ -2776,6 +2818,8 @@ class SyncHandler: elif event.membership == Membership.INVITE: if event.sender in ignored_users: continue + if invite_config.get_invite_rule(event.sender) == InviteRule.IGNORE: + continue invite = await self.store.get_event(event.event_id) invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite)) elif event.membership == Membership.KNOCK: @@ -2947,6 +2991,7 @@ class SyncHandler: since_token, room_builder.end_token, full_state=full_state, + joined=room_builder.rtype == "joined", ) else: # An out of band room won't have any state changes. diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 32dca8c43b..477961d78c 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py
@@ -157,104 +157,6 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): ) -class _BaseThreepidAuthChecker: - def __init__(self, hs: "HomeServer"): - self.hs = hs - self.store = hs.get_datastores().main - - async def _check_threepid(self, medium: str, authdict: dict) -> dict: - if "threepid_creds" not in authdict: - raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM) - - threepid_creds = authdict["threepid_creds"] - - identity_handler = self.hs.get_identity_handler() - - logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,)) - - # msisdns are currently always verified via the IS - if medium == "msisdn": - if not self.hs.config.registration.account_threepid_delegate_msisdn: - raise SynapseError( - 400, "Phone number verification is not enabled on this homeserver" - ) - threepid = await identity_handler.threepid_from_creds( - self.hs.config.registration.account_threepid_delegate_msisdn, - threepid_creds, - ) - elif medium == "email": - if self.hs.config.email.can_verify_email: - threepid = None - row = await self.store.get_threepid_validation_session( - medium, - threepid_creds["client_secret"], - sid=threepid_creds["sid"], - validated=True, - ) - - if row: - threepid = { - "medium": row.medium, - "address": row.address, - "validated_at": row.validated_at, - } - - # Valid threepid returned, delete from the db - await self.store.delete_threepid_session(threepid_creds["sid"]) - else: - raise SynapseError( - 400, "Email address verification is not enabled on this homeserver" - ) - else: - # this can't happen! 
- raise AssertionError("Unrecognized threepid medium: %s" % (medium,)) - - if not threepid: - raise LoginError( - 401, "Unable to get validated threepid", errcode=Codes.UNAUTHORIZED - ) - - if threepid["medium"] != medium: - raise LoginError( - 401, - "Expecting threepid of type '%s', got '%s'" - % (medium, threepid["medium"]), - errcode=Codes.UNAUTHORIZED, - ) - - threepid["threepid_creds"] = authdict["threepid_creds"] - - return threepid - - -class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): - AUTH_TYPE = LoginType.EMAIL_IDENTITY - - def __init__(self, hs: "HomeServer"): - UserInteractiveAuthChecker.__init__(self, hs) - _BaseThreepidAuthChecker.__init__(self, hs) - - def is_enabled(self) -> bool: - return self.hs.config.email.can_verify_email - - async def check_auth(self, authdict: dict, clientip: str) -> Any: - return await self._check_threepid("email", authdict) - - -class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): - AUTH_TYPE = LoginType.MSISDN - - def __init__(self, hs: "HomeServer"): - UserInteractiveAuthChecker.__init__(self, hs) - _BaseThreepidAuthChecker.__init__(self, hs) - - def is_enabled(self) -> bool: - return bool(self.hs.config.registration.account_threepid_delegate_msisdn) - - async def check_auth(self, authdict: dict, clientip: str) -> Any: - return await self._check_threepid("msisdn", authdict) - - class RegistrationTokenAuthChecker(UserInteractiveAuthChecker): AUTH_TYPE = LoginType.REGISTRATION_TOKEN @@ -263,7 +165,7 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker): self.hs = hs self._enabled = bool( hs.config.registration.registration_requires_token - ) or bool(hs.config.registration.enable_registration_token_3pid_bypass) + ) self.store = hs.get_datastores().main def is_enabled(self) -> bool: @@ -325,8 +227,6 @@ INTERACTIVE_AUTH_CHECKERS: Sequence[Type[UserInteractiveAuthChecker]] = [ DummyAuthChecker, TermsAuthChecker, RecaptchaAuthChecker, - 
EmailIdentityAuthChecker, - MsisdnAuthChecker, RegistrationTokenAuthChecker, ] """A list of UserInteractiveAuthChecker classes""" diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index a343637b82..33edef5f14 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py
@@ -26,7 +26,13 @@ from typing import TYPE_CHECKING, List, Optional, Set, Tuple from twisted.internet.interfaces import IDelayedCall import synapse.metrics -from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership +from synapse.api.constants import ( + EventTypes, + HistoryVisibility, + JoinRules, + Membership, + ProfileFields, +) from synapse.api.errors import Codes, SynapseError from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler from synapse.metrics.background_process_metrics import run_as_background_process @@ -102,6 +108,9 @@ class UserDirectoryHandler(StateDeltasHandler): self.is_mine_id = hs.is_mine_id self.update_user_directory = hs.config.worker.should_update_user_directory self.search_all_users = hs.config.userdirectory.user_directory_search_all_users + self.exclude_remote_users = ( + hs.config.userdirectory.user_directory_exclude_remote_users + ) self.show_locked_users = hs.config.userdirectory.show_locked_users self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker self._hs = hs @@ -161,7 +170,7 @@ class UserDirectoryHandler(StateDeltasHandler): non_spammy_users = [] for user in results["results"]: if not await self._spam_checker_module_callbacks.check_username_for_spam( - user + user, user_id ): non_spammy_users.append(user) results["results"] = non_spammy_users @@ -756,6 +765,10 @@ class UserDirectoryHandler(StateDeltasHandler): await self.store.update_profile_in_user_dir( user_id, - display_name=non_null_str_or_none(profile.get("displayname")), - avatar_url=non_null_str_or_none(profile.get("avatar_url")), + display_name=non_null_str_or_none( + profile.get(ProfileFields.DISPLAYNAME) + ), + avatar_url=non_null_str_or_none( + profile.get(ProfileFields.AVATAR_URL) + ), ) diff --git a/synapse/handlers/worker_lock.py b/synapse/handlers/worker_lock.py
index 7e578cf462..e58a416026 100644 --- a/synapse/handlers/worker_lock.py +++ b/synapse/handlers/worker_lock.py
@@ -19,6 +19,7 @@ # # +import logging import random from types import TracebackType from typing import ( @@ -183,7 +184,7 @@ class WorkerLocksHandler: return def _wake_all_locks( - locks: Collection[Union[WaitingLock, WaitingMultiLock]] + locks: Collection[Union[WaitingLock, WaitingMultiLock]], ) -> None: for lock in locks: deferred = lock.deferred @@ -269,6 +270,10 @@ class WaitingLock: def _get_next_retry_interval(self) -> float: next = self._retry_interval self._retry_interval = max(5, next * 2) + if self._retry_interval > 5 * 2**7: # ~10 minutes + logging.warning( + f"Lock timeout is getting excessive: {self._retry_interval}s. There may be a deadlock." + ) return next * random.uniform(0.9, 1.1) @@ -344,4 +349,8 @@ class WaitingMultiLock: def _get_next_retry_interval(self) -> float: next = self._retry_interval self._retry_interval = max(5, next * 2) + if self._retry_interval > 5 * 2**7: # ~10 minutes + logging.warning( + f"Lock timeout is getting excessive: {self._retry_interval}s. There may be a deadlock." + ) return next * random.uniform(0.9, 1.1) diff --git a/synapse/http/client.py b/synapse/http/client.py
index 56ad28eabf..84a510fb42 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py
@@ -31,18 +31,17 @@ from typing import ( List, Mapping, Optional, + Protocol, Tuple, Union, ) import attr -import multipart import treq from canonicaljson import encode_canonical_json from netaddr import AddrFormatError, IPAddress, IPSet from prometheus_client import Counter -from typing_extensions import Protocol -from zope.interface import implementer, provider +from zope.interface import implementer from OpenSSL import SSL from OpenSSL.SSL import VERIFY_NONE @@ -93,6 +92,20 @@ from synapse.util.async_helpers import timeout_deferred if TYPE_CHECKING: from synapse.server import HomeServer +# Support both import names for the `python-multipart` (PyPI) library, +# which renamed its package name from `multipart` to `python_multipart` +# in 0.0.13 (though supports the old import name for compatibility). +# Note that the `multipart` package name conflicts with `multipart` (PyPI) +# so we should prefer importing from `python_multipart` when possible. +try: + from python_multipart import MultipartParser + + if TYPE_CHECKING: + from python_multipart import multipart +except ImportError: + from multipart import MultipartParser # type: ignore[no-redef] + + logger = logging.getLogger(__name__) outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"]) @@ -212,7 +225,7 @@ class _IPBlockingResolver: recv.addressResolved(address) recv.resolutionComplete() - @provider(IResolutionReceiver) + @implementer(IResolutionReceiver) class EndpointReceiver: @staticmethod def resolutionBegan(resolutionInProgress: IHostResolution) -> None: @@ -226,8 +239,9 @@ class _IPBlockingResolver: def resolutionComplete() -> None: _callback() + endpoint_receiver_wrapper = EndpointReceiver() self._reactor.nameResolver.resolveHostName( - EndpointReceiver, hostname, portNumber=portNumber + endpoint_receiver_wrapper, hostname, portNumber=portNumber ) return recv @@ -1039,7 +1053,7 @@ class _MultipartParserProtocol(protocol.Protocol): self.deferred = deferred self.boundary = 
boundary self.max_length = max_length - self.parser = None + self.parser: Optional[MultipartParser] = None self.multipart_response = MultipartResponse() self.has_redirect = False self.in_json = False @@ -1057,11 +1071,11 @@ class _MultipartParserProtocol(protocol.Protocol): if not self.parser: def on_header_field(data: bytes, start: int, end: int) -> None: - if data[start:end] == b"Location": + if data[start:end].lower() == b"location": self.has_redirect = True - if data[start:end] == b"Content-Disposition": + if data[start:end].lower() == b"content-disposition": self.in_disposition = True - if data[start:end] == b"Content-Type": + if data[start:end].lower() == b"content-type": self.in_content_type = True def on_header_value(data: bytes, start: int, end: int) -> None: @@ -1088,7 +1102,6 @@ class _MultipartParserProtocol(protocol.Protocol): return # otherwise we are in the file part else: - logger.info("Writing multipart file data to stream") try: self.stream.write(data[start:end]) except Exception as e: @@ -1098,12 +1111,12 @@ class _MultipartParserProtocol(protocol.Protocol): self.deferred.errback() self.file_length += end - start - callbacks = { + callbacks: "multipart.MultipartCallbacks" = { "on_header_field": on_header_field, "on_header_value": on_header_value, "on_part_data": on_part_data, } - self.parser = multipart.MultipartParser(self.boundary, callbacks) + self.parser = MultipartParser(self.boundary, callbacks) self.total_length += len(incoming_data) if self.max_length is not None and self.total_length >= self.max_length: @@ -1114,7 +1127,7 @@ class _MultipartParserProtocol(protocol.Protocol): self.transport.abortConnection() try: - self.parser.write(incoming_data) # type: ignore[attr-defined] + self.parser.write(incoming_data) except Exception as e: logger.warning(f"Exception writing to multipart parser: {e}") self.deferred.errback() @@ -1314,6 +1327,5 @@ def is_unknown_endpoint( ) ) or ( # Older Synapses returned a 400 error. 
- e.code == 400 - and synapse_error.errcode == Codes.UNRECOGNIZED + e.code == 400 and synapse_error.errcode == Codes.UNRECOGNIZED ) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 6fd75fd381..88bf98045c 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py
@@ -19,7 +19,6 @@ # # import abc -import cgi import codecs import logging import random @@ -35,6 +34,7 @@ from typing import ( Dict, Generic, List, + Literal, Optional, TextIO, Tuple, @@ -49,7 +49,6 @@ import treq from canonicaljson import encode_canonical_json from prometheus_client import Counter from signedjson.sign import sign_json -from typing_extensions import Literal from twisted.internet import defer from twisted.internet.error import DNSLookupError @@ -426,9 +425,9 @@ class MatrixFederationHttpClient: ) else: proxy_authorization_secret = hs.config.worker.worker_replication_secret - assert ( - proxy_authorization_secret is not None - ), "`worker_replication_secret` must be set when using `outbound_federation_restricted_to` (used to authenticate requests across workers)" + assert proxy_authorization_secret is not None, ( + "`worker_replication_secret` must be set when using `outbound_federation_restricted_to` (used to authenticate requests across workers)" + ) federation_proxy_credentials = BearerProxyCredentials( proxy_authorization_secret.encode("ascii") ) @@ -792,7 +791,7 @@ class MatrixFederationHttpClient: url_str, _flatten_response_never_received(e), ) - body = None + body = b"" exc = HttpResponseException( response.code, response_phrase, body @@ -1756,8 +1755,10 @@ class MatrixFederationHttpClient: request.destination, str_url, ) + # We don't know how large the response will be upfront, so limit it to + # the `max_size` config value. 
length, headers, _, _ = await self._simple_http_client.get_file( - str_url, output_stream, expected_size + str_url, output_stream, max_size ) logger.info( @@ -1811,8 +1812,9 @@ def check_content_type_is(headers: Headers, expected_content_type: str) -> None: ) c_type = content_type_headers[0].decode("ascii") # only the first header - val, options = cgi.parse_header(c_type) - if val != expected_content_type: + # Extract the 'essence' of the mimetype, removing any parameter + c_type_parsed = c_type.split(";", 1)[0].strip() + if c_type_parsed != expected_content_type: raise RequestSendFailed( RuntimeError( f"Remote server sent Content-Type header of '{c_type}', not '{expected_content_type}'", diff --git a/synapse/http/proxy.py b/synapse/http/proxy.py
index 97aa429e7d..5cd990b0d0 100644 --- a/synapse/http/proxy.py +++ b/synapse/http/proxy.py
@@ -51,25 +51,17 @@ logger = logging.getLogger(__name__) # "Hop-by-hop" headers (as opposed to "end-to-end" headers) as defined by RFC2616 # section 13.5.1 and referenced in RFC9110 section 7.6.1. These are meant to only be # consumed by the immediate recipient and not be forwarded on. -HOP_BY_HOP_HEADERS = { - "Connection", - "Keep-Alive", - "Proxy-Authenticate", - "Proxy-Authorization", - "TE", - "Trailers", - "Transfer-Encoding", - "Upgrade", +HOP_BY_HOP_HEADERS_LOWERCASE = { + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailers", + "transfer-encoding", + "upgrade", } - -if hasattr(Headers, "_canonicalNameCaps"): - # Twisted < 24.7.0rc1 - _canonicalHeaderName = Headers()._canonicalNameCaps # type: ignore[attr-defined] -else: - # Twisted >= 24.7.0rc1 - # But note that `_encodeName` still exists on prior versions, - # it just encodes differently - _canonicalHeaderName = Headers()._encodeName +assert all(header.lower() == header for header in HOP_BY_HOP_HEADERS_LOWERCASE) def parse_connection_header_value( @@ -92,12 +84,12 @@ def parse_connection_header_value( Returns: The set of header names that should not be copied over from the remote response. - The keys are capitalized in canonical capitalization. + The keys are lowercased. """ extra_headers_to_remove: Set[str] = set() if connection_header_value: extra_headers_to_remove = { - _canonicalHeaderName(connection_option.strip()).decode("ascii") + connection_option.decode("ascii").strip().lower() for connection_option in connection_header_value.split(b",") } @@ -194,7 +186,7 @@ class ProxyResource(_AsyncResource): # The `Connection` header also defines which headers should not be copied over. 
connection_header = response_headers.getRawHeaders(b"connection") - extra_headers_to_remove = parse_connection_header_value( + extra_headers_to_remove_lowercase = parse_connection_header_value( connection_header[0] if connection_header else None ) @@ -202,10 +194,10 @@ class ProxyResource(_AsyncResource): for k, v in response_headers.getAllRawHeaders(): # Do not copy over any hop-by-hop headers. These are meant to only be # consumed by the immediate recipient and not be forwarded on. - header_key = k.decode("ascii") + header_key_lowercase = k.decode("ascii").lower() if ( - header_key in HOP_BY_HOP_HEADERS - or header_key in extra_headers_to_remove + header_key_lowercase in HOP_BY_HOP_HEADERS_LOWERCASE + or header_key_lowercase in extra_headers_to_remove_lowercase ): continue diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index f80f67acc6..6817199035 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py
@@ -21,7 +21,7 @@ import logging import random import re -from typing import Any, Collection, Dict, List, Optional, Sequence, Tuple +from typing import Any, Collection, Dict, List, Optional, Sequence, Tuple, Union from urllib.parse import urlparse from urllib.request import ( # type: ignore[attr-defined] getproxies_environment, @@ -150,6 +150,12 @@ class ProxyAgent(_AgentBase): http_proxy = proxies["http"].encode() if "http" in proxies else None https_proxy = proxies["https"].encode() if "https" in proxies else None no_proxy = proxies["no"] if "no" in proxies else None + logger.debug( + "Using proxy settings: http_proxy=%s, https_proxy=%s, no_proxy=%s", + http_proxy, + https_proxy, + no_proxy, + ) self.http_proxy_endpoint, self.http_proxy_creds = http_proxy_endpoint( http_proxy, self.proxy_reactor, contextFactory, **self._endpoint_kwargs @@ -167,9 +173,9 @@ class ProxyAgent(_AgentBase): self._federation_proxy_endpoint: Optional[IStreamClientEndpoint] = None self._federation_proxy_credentials: Optional[ProxyCredentials] = None if federation_proxy_locations: - assert ( - federation_proxy_credentials is not None - ), "`federation_proxy_credentials` are required when using `federation_proxy_locations`" + assert federation_proxy_credentials is not None, ( + "`federation_proxy_credentials` are required when using `federation_proxy_locations`" + ) endpoints: List[IStreamClientEndpoint] = [] for federation_proxy_location in federation_proxy_locations: @@ -296,9 +302,9 @@ class ProxyAgent(_AgentBase): parsed_uri.scheme == b"matrix-federation" and self._federation_proxy_endpoint ): - assert ( - self._federation_proxy_credentials is not None - ), "`federation_proxy_credentials` are required when using `federation_proxy_locations`" + assert self._federation_proxy_credentials is not None, ( + "`federation_proxy_credentials` are required when using `federation_proxy_locations`" + ) # Set a Proxy-Authorization header if headers is None: @@ -351,7 +357,9 @@ def 
http_proxy_endpoint( proxy: Optional[bytes], reactor: IReactorCore, tls_options_factory: Optional[IPolicyForHTTPS], - **kwargs: object, + timeout: float = 30, + bindAddress: Optional[Union[bytes, str, tuple[Union[bytes, str], int]]] = None, + attemptDelay: Optional[float] = None, ) -> Tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]: """Parses an http proxy setting and returns an endpoint for the proxy @@ -382,12 +390,15 @@ def http_proxy_endpoint( # 3.9+) on scheme-less proxies, e.g. host:port. scheme, host, port, credentials = parse_proxy(proxy) - proxy_endpoint = HostnameEndpoint(reactor, host, port, **kwargs) + proxy_endpoint = HostnameEndpoint( + reactor, host, port, timeout, bindAddress, attemptDelay + ) if scheme == b"https": if tls_options_factory: tls_options = tls_options_factory.creatorForNetloc(host, port) - proxy_endpoint = wrapClientTLS(tls_options, proxy_endpoint) + wrapped_proxy_endpoint = wrapClientTLS(tls_options, proxy_endpoint) + return wrapped_proxy_endpoint, credentials else: raise RuntimeError( f"No TLS options for a https connection via proxy {proxy!s}" diff --git a/synapse/http/replicationagent.py b/synapse/http/replicationagent.py
index ee8c707062..4eabbc8af9 100644 --- a/synapse/http/replicationagent.py +++ b/synapse/http/replicationagent.py
@@ -89,7 +89,7 @@ class ReplicationEndpointFactory: location_config.port, ) if scheme == "https": - endpoint = wrapClientTLS( + wrapped_endpoint = wrapClientTLS( # The 'port' argument below isn't actually used by the function self.context_factory.creatorForNetloc( location_config.host.encode("utf-8"), @@ -97,6 +97,8 @@ class ReplicationEndpointFactory: ), endpoint, ) + return wrapped_endpoint + return endpoint elif isinstance(location_config, InstanceUnixLocationConfig): return UNIXClientEndpoint(self.reactor, location_config.path) diff --git a/synapse/http/server.py b/synapse/http/server.py
index 0d0c610b28..bdd90d8a73 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py
@@ -39,6 +39,7 @@ from typing import ( List, Optional, Pattern, + Protocol, Tuple, Union, ) @@ -46,7 +47,6 @@ from typing import ( import attr import jinja2 from canonicaljson import encode_canonical_json -from typing_extensions import Protocol from zope.interface import implementer from twisted.internet import defer, interfaces @@ -74,7 +74,6 @@ from synapse.api.errors import ( from synapse.config.homeserver import HomeServerConfig from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background from synapse.logging.opentracing import active_span, start_active_span, trace_servlet -from synapse.types import ISynapseReactor from synapse.util import json_encoder from synapse.util.caches import intern_dict from synapse.util.cancellation import is_function_cancellable @@ -142,7 +141,7 @@ def return_json_error( ) else: error_code = 500 - error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN} + error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN, "data": f.getTraceback()} logger.error( "Failed handle request via %r: %r", @@ -234,7 +233,7 @@ def return_html_error( def wrap_async_request_handler( - h: Callable[["_AsyncResource", "SynapseRequest"], Awaitable[None]] + h: Callable[["_AsyncResource", "SynapseRequest"], Awaitable[None]], ) -> Callable[["_AsyncResource", "SynapseRequest"], "defer.Deferred[None]"]: """Wraps an async request handler so that it calls request.processing. 
@@ -869,8 +868,7 @@ async def _async_write_json_to_request_in_thread( with start_active_span("encode_json_response"): span = active_span() - reactor: ISynapseReactor = request.reactor # type: ignore - json_str = await defer_to_thread(reactor, encode, span) + json_str = await defer_to_thread(request.reactor, encode, span) _write_bytes_to_request(request, json_str) @@ -923,15 +921,6 @@ def set_cors_headers(request: "SynapseRequest") -> None: b"Access-Control-Expose-Headers", b"Synapse-Trace-Id, Server, ETag", ) - elif request.experimental_cors_msc3886: - request.setHeader( - b"Access-Control-Allow-Headers", - b"X-Requested-With, Content-Type, Authorization, Date, If-Match, If-None-Match", - ) - request.setHeader( - b"Access-Control-Expose-Headers", - b"ETag, Location, X-Max-Bytes", - ) else: request.setHeader( b"Access-Control-Allow-Headers", diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 08b8ff7afd..47d8bd5eaf 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py
@@ -28,6 +28,7 @@ from http import HTTPStatus from typing import ( TYPE_CHECKING, List, + Literal, Mapping, Optional, Sequence, @@ -37,19 +38,15 @@ from typing import ( overload, ) -from synapse._pydantic_compat import HAS_PYDANTIC_V2 - -if TYPE_CHECKING or HAS_PYDANTIC_V2: - from pydantic.v1 import BaseModel, MissingError, PydanticValueError, ValidationError - from pydantic.v1.error_wrappers import ErrorWrapper -else: - from pydantic import BaseModel, MissingError, PydanticValueError, ValidationError - from pydantic.error_wrappers import ErrorWrapper - -from typing_extensions import Literal - from twisted.web.server import Request +from synapse._pydantic_compat import ( + BaseModel, + ErrorWrapper, + MissingError, + PydanticValueError, + ValidationError, +) from synapse.api.errors import Codes, SynapseError from synapse.http import redact_uri from synapse.http.server import HttpServer @@ -585,9 +582,9 @@ def parse_enum( is not one of those allowed values. """ # Assert the enum values are strings. - assert all( - isinstance(e.value, str) for e in E - ), "parse_enum only works with string values" + assert all(isinstance(e.value, str) for e in E), ( + "parse_enum only works with string values" + ) str_value = parse_string( request, name, diff --git a/synapse/http/site.py b/synapse/http/site.py
index af169ba51e..e83a4447b2 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py
@@ -21,6 +21,7 @@ import contextlib import logging import time +from http import HTTPStatus from typing import TYPE_CHECKING, Any, Generator, Optional, Tuple, Union import attr @@ -94,7 +95,6 @@ class SynapseRequest(Request): self.reactor = site.reactor self._channel = channel # this is used by the tests self.start_time = 0.0 - self.experimental_cors_msc3886 = site.experimental_cors_msc3886 # The requester, if authenticated. For federation requests this is the # server name, for client requests this is the Requester object. @@ -140,6 +140,41 @@ class SynapseRequest(Request): self.synapse_site.site_tag, ) + # Twisted machinery: this method is called by the Channel once the full request has + # been received, to dispatch the request to a resource. + # + # We're patching Twisted to bail/abort early when we see someone trying to upload + # `multipart/form-data` so we can avoid Twisted parsing the entire request body into + # memory (specific problem of this specific `Content-Type`). This protects us + # from an attacker uploading something bigger than the available RAM and crashing + # the server with a `MemoryError`, or carefully blocking just enough resources to cause + # all other requests to fail. 
+ # + # FIXME: This can be removed once Twisted releases a fix and we update to a + # version that is patched + def requestReceived(self, command: bytes, path: bytes, version: bytes) -> None: + if command == b"POST": + ctype = self.requestHeaders.getRawHeaders(b"content-type") + if ctype and b"multipart/form-data" in ctype[0].lower(): # media types are case-insensitive + self.method, self.uri = command, path + self.clientproto = version + self.code = HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value + self.code_message = bytes( + HTTPStatus.UNSUPPORTED_MEDIA_TYPE.phrase, "ascii" + ) + self.responseHeaders.setRawHeaders(b"content-length", [b"0"]) + + logger.warning( + "Aborting connection from %s because `content-type: multipart/form-data` is unsupported: %s %s", + self.client, + command, + path, + ) + self.write(b"") + self.loseConnection() + return + return super().requestReceived(command, path, version) + def handleContentChunk(self, data: bytes) -> None: # we should have a `content` by now. assert self.content, "handleContentChunk() called before gotLength()" @@ -658,7 +693,7 @@ class SynapseSite(ProxySite): ) self.site_tag = site_tag - self.reactor = reactor + self.reactor: ISynapseReactor = reactor assert config.http_options is not None proxied = config.http_options.x_forwarded @@ -666,10 +701,6 @@ class SynapseSite(ProxySite): request_id_header = config.http_options.request_id_header - self.experimental_cors_msc3886: bool = ( - config.http_options.experimental_cors_msc3886 - ) - def request_factory(channel: HTTPChannel, queued: bool) -> Request: return request_class( channel, diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py
index f047edee8e..ac34fa6525 100644 --- a/synapse/logging/_remote.py +++ b/synapse/logging/_remote.py
@@ -39,7 +39,7 @@ from twisted.internet.endpoints import ( ) from twisted.internet.interfaces import ( IPushProducer, - IReactorTCP, + IReactorTime, IStreamClientEndpoint, ) from twisted.internet.protocol import Factory, Protocol @@ -113,7 +113,7 @@ class RemoteHandler(logging.Handler): port: int, maximum_buffer: int = 1000, level: int = logging.NOTSET, - _reactor: Optional[IReactorTCP] = None, + _reactor: Optional[IReactorTime] = None, ): super().__init__(level=level) self.host = host diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py
index 6a6afbfc0b..d9ff70b252 100644 --- a/synapse/logging/_terse_json.py +++ b/synapse/logging/_terse_json.py
@@ -22,6 +22,7 @@ """ Log formatters that output terse JSON. """ + import json import logging diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 4650b60962..3ef97f23c9 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py
@@ -20,7 +20,7 @@ # # -""" Thread-local-alike tracking of log contexts within synapse +"""Thread-local-alike tracking of log contexts within synapse This module provides objects and utilities for tracking contexts through synapse code, so that log lines can include a request identifier, and so that @@ -29,6 +29,7 @@ them. See doc/log_contexts.rst for details on how this works. """ + import logging import threading import typing @@ -36,8 +37,10 @@ import warnings from types import TracebackType from typing import ( TYPE_CHECKING, + Any, Awaitable, Callable, + Literal, Optional, Tuple, Type, @@ -47,7 +50,7 @@ from typing import ( ) import attr -from typing_extensions import Literal, ParamSpec +from typing_extensions import ParamSpec from twisted.internet import defer, threads from twisted.python.threadpool import ThreadPool @@ -751,7 +754,7 @@ def preserve_fn( f: Union[ Callable[P, R], Callable[P, Awaitable[R]], - ] + ], ) -> Callable[P, "defer.Deferred[R]"]: """Function decorator which wraps the function with run_in_background""" @@ -849,6 +852,45 @@ def run_in_background( return d +def run_coroutine_in_background( + coroutine: typing.Coroutine[Any, Any, R], +) -> "defer.Deferred[R]": + """Run the coroutine, ensuring that the current context is restored after + return from the function, and that the sentinel context is set once the + deferred returned by the function completes. + + Useful for wrapping coroutines that you don't yield or await on (for + instance because you want to pass it to deferred.gatherResults()). + + This is a special case of `run_in_background` where we can accept a + coroutine directly rather than a function. We can do this because coroutines + do not run until called, and so calling an async function without awaiting + cannot change the log contexts. + """ + + current = current_context() + d = defer.ensureDeferred(coroutine) + + # The function may have reset the context before returning, so + # we need to restore it now. 
+ ctx = set_current_context(current) + + # The original context will be restored when the deferred + # completes, but there is nothing waiting for it, so it will + # get leaked into the reactor or some other function which + # wasn't expecting it. We therefore need to reset the context + # here. + # + # (If this feels asymmetric, consider it this way: we are + # effectively forking a new thread of execution. We are + # probably currently within a ``with LoggingContext()`` block, + # which is supposed to have a single entry and exit point. But + # by spawning off another deferred, we are effectively + # adding a new exit point.) + d.addBoth(_set_context_cb, ctx) + return d + + T = TypeVar("T") diff --git a/synapse/logging/filter.py b/synapse/logging/filter.py
index 11c27c63f2..16de488dbc 100644 --- a/synapse/logging/filter.py +++ b/synapse/logging/filter.py
@@ -19,8 +19,7 @@ # # import logging - -from typing_extensions import Literal +from typing import Literal class MetadataFilter(logging.Filter): diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 7a3c805cc5..d976e58e49 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py
@@ -169,6 +169,7 @@ Gotchas than one caller? Will all of those calling functions have be in a context with an active span? """ + import contextlib import enum import inspect @@ -414,7 +415,7 @@ def ensure_active_span( """ def ensure_active_span_inner_1( - func: Callable[P, R] + func: Callable[P, R], ) -> Callable[P, Union[Optional[T], R]]: @wraps(func) def ensure_active_span_inner_2( @@ -700,7 +701,7 @@ def set_operation_name(operation_name: str) -> None: @only_if_tracing def force_tracing( - span: Union["opentracing.Span", _Sentinel] = _Sentinel.sentinel + span: Union["opentracing.Span", _Sentinel] = _Sentinel.sentinel, ) -> None: """Force sampling for the active/given span and its children. @@ -1032,13 +1033,13 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]: def _wrapping_logic( _func: Callable[P, R], *args: P.args, **kwargs: P.kwargs ) -> Generator[None, None, None]: - # We use `[1:]` to skip the `self` object reference and `start=1` to - # make the index line up with `argspec.args`. - # - # FIXME: We could update this to handle any type of function by ignoring the - # first argument only if it's named `self` or `cls`. This isn't fool-proof - # but handles the idiomatic cases. - for i, arg in enumerate(args[1:], start=1): + for i, arg in enumerate(args, start=0): + if argspec.args[i] in ("self", "cls"): + # Ignore `self` and `cls` values. Ideally we'd properly detect + # if we were wrapping a method, but that is really non-trivial + # and this is good enough. + continue + set_tag(SynapseTags.FUNC_ARG_PREFIX + argspec.args[i], str(arg)) set_tag(SynapseTags.FUNC_ARGS, str(args[len(argspec.args) :])) set_tag(SynapseTags.FUNC_KWARGS, str(kwargs)) @@ -1093,9 +1094,10 @@ def trace_servlet( # Mypy seems to think that start_context.tag below can be Optional[str], but # that doesn't appear to be correct and works in practice. 
- request_tags[ - SynapseTags.REQUEST_TAG - ] = request.request_metrics.start_context.tag # type: ignore[assignment] + + request_tags[SynapseTags.REQUEST_TAG] = ( + request.request_metrics.start_context.tag # type: ignore[assignment] + ) # set the tags *after* the servlet completes, in case it decided to # prioritise the span (tags will get dropped on unprioritised spans) diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py
index 581e6d6411..feaadc4d87 100644 --- a/synapse/logging/scopecontextmanager.py +++ b/synapse/logging/scopecontextmanager.py
@@ -20,13 +20,10 @@ # import logging -from types import TracebackType -from typing import Optional, Type +from typing import Optional from opentracing import Scope, ScopeManager, Span -import twisted - from synapse.logging.context import ( LoggingContext, current_context, @@ -112,9 +109,6 @@ class _LogContextScope(Scope): """ A custom opentracing scope, associated with a LogContext - * filters out _DefGen_Return exceptions which arise from calling - `defer.returnValue` in Twisted code - * When the scope is closed, the logcontext's active scope is reset to None. and - if enter_logcontext was set - the logcontext is finished too. """ @@ -146,17 +140,6 @@ class _LogContextScope(Scope): self._finish_on_close = finish_on_close self._enter_logcontext = enter_logcontext - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - value: Optional[BaseException], - traceback: Optional[TracebackType], - ) -> None: - if exc_type == twisted.internet.defer._DefGen_Return: - # filter out defer.returnValue() calls - exc_type = value = traceback = None - super().__exit__(exc_type, value, traceback) - def __str__(self) -> str: return f"Scope<{self.span}>" diff --git a/synapse/media/_base.py b/synapse/media/_base.py
index 1b268ce4d4..2e48d2fdc7 100644 --- a/synapse/media/_base.py +++ b/synapse/media/_base.py
@@ -28,6 +28,7 @@ from types import TracebackType from typing import ( TYPE_CHECKING, Awaitable, + BinaryIO, Dict, Generator, List, @@ -37,21 +38,28 @@ from typing import ( ) import attr +from zope.interface import implementer +from twisted.internet import interfaces +from twisted.internet.defer import Deferred from twisted.internet.interfaces import IConsumer -from twisted.protocols.basic import FileSender +from twisted.python.failure import Failure from twisted.web.server import Request from synapse.api.errors import Codes, cs_error from synapse.http.server import finish_request, respond_with_json from synapse.http.site import SynapseRequest -from synapse.logging.context import make_deferred_yieldable +from synapse.logging.context import ( + defer_to_threadpool, + make_deferred_yieldable, + run_in_background, +) from synapse.util import Clock +from synapse.util.async_helpers import DeferredEvent from synapse.util.stringutils import is_ascii if TYPE_CHECKING: - from synapse.storage.databases.main.media_repository import LocalMedia - + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -110,6 +118,9 @@ DEFAULT_MAX_TIMEOUT_MS = 20_000 # Maximum allowed timeout_ms for download and thumbnail requests MAXIMUM_ALLOWED_MAX_TIMEOUT_MS = 60_000 +# The ETag header value to use for immutable media. This can be anything. 
+_IMMUTABLE_ETAG = "1" + def respond_404(request: SynapseRequest) -> None: assert request.path is not None @@ -122,6 +133,7 @@ def respond_404(request: SynapseRequest) -> None: async def respond_with_file( + hs: "HomeServer", request: SynapseRequest, media_type: str, file_path: str, @@ -138,7 +150,7 @@ async def respond_with_file( add_file_headers(request, media_type, file_size, upload_name) with open(file_path, "rb") as f: - await make_deferred_yieldable(FileSender().beginFileTransfer(f, request)) + await ThreadedFileSender(hs).beginFileTransfer(f, request) finish_request(request) else: @@ -215,12 +227,7 @@ def add_file_headers( request.setHeader(b"Content-Disposition", disposition.encode("ascii")) - # cache for at least a day. - # XXX: we might want to turn this off for data we don't want to - # recommend caching as it's sensitive or private - or at least - # select private. don't bother setting Expires as all our - # clients are smart enough to be happy with Cache-Control - request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400") + _add_cache_headers(request) if file_size is not None: request.setHeader(b"Content-Length", b"%d" % (file_size,)) @@ -231,6 +238,26 @@ def add_file_headers( request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex") +def _add_cache_headers(request: Request) -> None: + """Adds the appropriate cache headers to the response""" + + # Cache on the client for at least a day. + # + # We set this to "public,s-maxage=0,proxy-revalidate" to allow CDNs to cache + # the media, so long as they "revalidate" the media on every request. By + # revalidate, we mean send the request to Synapse with a `If-None-Match` + # header, to which Synapse can either respond with a 304 if the user is + # authenticated/authorized, or a 401/403 if they're not. 
+ request.setHeader( + b"Cache-Control", b"public,max-age=86400,s-maxage=0,proxy-revalidate" + ) + + # Set an ETag header to allow requesters to use it in requests to check if + # the cache is still valid. Since media is immutable (though may be + # deleted), we just set this to a constant. + request.setHeader(b"ETag", _IMMUTABLE_ETAG) + + # separators as defined in RFC2616. SP and HT are handled separately. # see _can_encode_filename_as_token. _FILENAME_SEPARATOR_CHARS = { @@ -279,7 +306,9 @@ async def respond_with_multipart_responder( clock: Clock, request: SynapseRequest, responder: "Optional[Responder]", - media_info: "LocalMedia", + media_type: str, + media_length: Optional[int], + upload_name: Optional[str], ) -> None: """ Responds to requests originating from the federation media `/download` endpoint by @@ -303,7 +332,7 @@ async def respond_with_multipart_responder( ) return - if media_info.media_type.lower().split(";", 1)[0] in INLINE_CONTENT_TYPES: + if media_type.lower().split(";", 1)[0] in INLINE_CONTENT_TYPES: disposition = "inline" else: disposition = "attachment" @@ -311,33 +340,35 @@ async def respond_with_multipart_responder( def _quote(x: str) -> str: return urllib.parse.quote(x.encode("utf-8")) - if media_info.upload_name: - if _can_encode_filename_as_token(media_info.upload_name): + if upload_name: + if _can_encode_filename_as_token(upload_name): disposition = "%s; filename=%s" % ( disposition, - media_info.upload_name, + upload_name, ) else: disposition = "%s; filename*=utf-8''%s" % ( disposition, - _quote(media_info.upload_name), + _quote(upload_name), ) from synapse.media.media_storage import MultipartFileConsumer + _add_cache_headers(request) + # note that currently the json_object is just {}, this will change when linked media # is implemented multipart_consumer = MultipartFileConsumer( clock, request, - media_info.media_type, - {}, + media_type, + {}, # Note: if we change this we need to change the returned ETag. 
disposition, - media_info.media_length, + media_length, ) logger.debug("Responding to media request with responder %s", responder) - if media_info.media_length is not None: + if media_length is not None: content_length = multipart_consumer.content_length() assert content_length is not None request.setHeader(b"Content-Length", b"%d" % (content_length,)) @@ -408,6 +439,46 @@ async def respond_with_responder( finish_request(request) +def respond_with_304(request: SynapseRequest) -> None: + request.setResponseCode(304) + + # could alternatively use request.notifyFinish() and flip a flag when + # the Deferred fires, but since the flag is RIGHT THERE it seems like + # a waste. + if request._disconnected: + logger.warning( + "Not sending response to request %s, already disconnected.", request + ) + return None + + _add_cache_headers(request) + + request.finish() + + +def check_for_cached_entry_and_respond(request: SynapseRequest) -> bool: + """Check if the request has a conditional header that allows us to return a + 304 Not Modified response, and if it does, return a 304 response. + + This handles clients and intermediary proxies caching media. + This method assumes that the user has already been + authorised to request the media. + + Returns True if we have responded.""" + + # We've checked the user has access to the media, so we now check if it + # is a "conditional request" and we can just return a `304 Not Modified` + # response. Since media is immutable (though may be deleted), we just + # check this is the expected constant. + etag = request.getHeader("If-None-Match") + if etag == _IMMUTABLE_ETAG: + # Return a `304 Not modified`. + respond_with_304(request) + return True + + return False + + class Responder(ABC): """Represents a response that can be streamed to the requester. 
@@ -601,3 +672,151 @@ def _parseparam(s: bytes) -> Generator[bytes, None, None]: f = s[:end] yield f.strip() s = s[end:] + + +@implementer(interfaces.IPushProducer) +class ThreadedFileSender: + """ + A producer that sends the contents of a file to a consumer, reading from the + file on a thread. + + This works by having a loop in a threadpool repeatedly reading from the + file, until the consumer pauses the producer. There is then a loop in the + main thread that waits until the consumer resumes the producer and then + starts reading in the threadpool again. + + This is done to ensure that we're never waiting in the threadpool, as + otherwise its easy to starve it of threads. + """ + + # How much data to read in one go. + CHUNK_SIZE = 2**14 + + # How long we wait for the consumer to be ready again before aborting the + # read. + TIMEOUT_SECONDS = 90.0 + + def __init__(self, hs: "HomeServer") -> None: + self.reactor = hs.get_reactor() + self.thread_pool = hs.get_media_sender_thread_pool() + + self.file: Optional[BinaryIO] = None + self.deferred: "Deferred[None]" = Deferred() + self.consumer: Optional[interfaces.IConsumer] = None + + # Signals if the thread should keep reading/sending data. Set means + # continue, clear means pause. + self.wakeup_event = DeferredEvent(self.reactor) + + # Signals if the thread should terminate, e.g. because the consumer has + # gone away. + self.stop_writing = False + + def beginFileTransfer( + self, file: BinaryIO, consumer: interfaces.IConsumer + ) -> "Deferred[None]": + """ + Begin transferring a file + """ + self.file = file + self.consumer = consumer + + self.consumer.registerProducer(self, True) + + # We set the wakeup signal as we should start producing immediately. 
+ self.wakeup_event.set() + run_in_background(self.start_read_loop) + + return make_deferred_yieldable(self.deferred) + + def resumeProducing(self) -> None: + """interfaces.IPushProducer""" + self.wakeup_event.set() + + def pauseProducing(self) -> None: + """interfaces.IPushProducer""" + self.wakeup_event.clear() + + def stopProducing(self) -> None: + """interfaces.IPushProducer""" + + # Unregister the consumer so we don't try and interact with it again. + if self.consumer: + self.consumer.unregisterProducer() + + self.consumer = None + + # Terminate the loop. + self.stop_writing = True + self.wakeup_event.set() + + if not self.deferred.called: + self.deferred.errback(Exception("Consumer asked us to stop producing")) + + async def start_read_loop(self) -> None: + """This is the loop that drives reading/writing""" + try: + while not self.stop_writing: + # Start the loop in the threadpool to read data. + more_data = await defer_to_threadpool( + self.reactor, self.thread_pool, self._on_thread_read_loop + ) + if not more_data: + # Reached EOF, we can just return. + return + + if not self.wakeup_event.is_set(): + ret = await self.wakeup_event.wait(self.TIMEOUT_SECONDS) + if not ret: + raise Exception("Timed out waiting to resume") + except Exception: + self._error(Failure()) + finally: + self._finish() + + def _on_thread_read_loop(self) -> bool: + """This is the loop that happens on a thread. + + Returns: + Whether there is more data to send. + """ + + while not self.stop_writing and self.wakeup_event.is_set(): + # The file should always have been set before we get here. 
+ assert self.file is not None + + chunk = self.file.read(self.CHUNK_SIZE) + if not chunk: + return False + + self.reactor.callFromThread(self._write, chunk) + + return True + + def _write(self, chunk: bytes) -> None: + """Called from the thread to write a chunk of data""" + if self.consumer: + self.consumer.write(chunk) + + def _error(self, failure: Failure) -> None: + """Called when there was a fatal error""" + if self.consumer: + self.consumer.unregisterProducer() + self.consumer = None + + if not self.deferred.called: + self.deferred.errback(failure) + + def _finish(self) -> None: + """Called when we have finished writing (either on success or + failure).""" + if self.file: + self.file.close() + self.file = None + + if self.consumer: + self.consumer.unregisterProducer() + self.consumer = None + + if not self.deferred.called: + self.deferred.callback(None) diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 8bc92305fe..18c5a8ecec 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py
@@ -52,13 +52,18 @@ from synapse.media._base import ( FileInfo, Responder, ThumbnailInfo, + check_for_cached_entry_and_respond, get_filename_from_headers, respond_404, respond_with_multipart_responder, respond_with_responder, ) from synapse.media.filepath import MediaFilePaths -from synapse.media.media_storage import MediaStorage +from synapse.media.media_storage import ( + MediaStorage, + SHA256TransparentIOReader, + SHA256TransparentIOWriter, +) from synapse.media.storage_provider import StorageProviderWrapper from synapse.media.thumbnailer import Thumbnailer, ThumbnailError from synapse.media.url_previewer import UrlPreviewer @@ -259,7 +264,7 @@ class MediaRepository: """ media = await self.store.get_local_media(media_id) if media is None: - raise SynapseError(404, "Unknow media ID", errcode=Codes.NOT_FOUND) + raise NotFoundError("Unknown media ID") if media.user_id != auth_user.to_string(): raise SynapseError( @@ -300,15 +305,26 @@ class MediaRepository: auth_user: The user_id of the uploader """ file_info = FileInfo(server_name=None, file_id=media_id) - fname = await self.media_storage.store_file(content, file_info) + sha256reader = SHA256TransparentIOReader(content) + # This implements all of IO as it has a passthrough + fname = await self.media_storage.store_file(sha256reader.wrap(), file_info) + sha256 = sha256reader.hexdigest() + should_quarantine = await self.store.get_is_hash_quarantined(sha256) logger.info("Stored local media in file %r", fname) + if should_quarantine: + logger.warn( + "Media has been automatically quarantined as it matched existing quarantined media" + ) + await self.store.update_local_media( media_id=media_id, media_type=media_type, upload_name=upload_name, media_length=content_length, user_id=auth_user, + sha256=sha256, + quarantined_by="system" if should_quarantine else None, ) try: @@ -341,11 +357,19 @@ class MediaRepository: media_id = random_string(24) file_info = FileInfo(server_name=None, file_id=media_id) - - fname = await 
self.media_storage.store_file(content, file_info) + # This implements all of IO as it has a passthrough + sha256reader = SHA256TransparentIOReader(content) + fname = await self.media_storage.store_file(sha256reader.wrap(), file_info) + sha256 = sha256reader.hexdigest() + should_quarantine = await self.store.get_is_hash_quarantined(sha256) logger.info("Stored local media in file %r", fname) + if should_quarantine: + logger.warn( + "Media has been automatically quarantined as it matched existing quarantined media" + ) + await self.store.store_local_media( media_id=media_id, media_type=media_type, @@ -353,6 +377,8 @@ class MediaRepository: upload_name=upload_name, media_length=content_length, user_id=auth_user, + sha256=sha256, + quarantined_by="system" if should_quarantine else None, ) try: @@ -459,6 +485,11 @@ class MediaRepository: self.mark_recently_accessed(None, media_id) + # Once we've checked auth we can return early if the media is cached on + # the client + if check_for_cached_entry_and_respond(request): + return + media_type = media_info.media_type if not media_type: media_type = "application/octet-stream" @@ -471,7 +502,7 @@ class MediaRepository: responder = await self.media_storage.fetch_media(file_info) if federation: await respond_with_multipart_responder( - self.clock, request, responder, media_info + self.clock, request, responder, media_type, media_length, upload_name ) else: await respond_with_responder( @@ -538,6 +569,17 @@ class MediaRepository: allow_authenticated, ) + # Check if the media is cached on the client, if so return 304. We need + # to do this after we have fetched remote media, as we need it to do the + # auth. + if check_for_cached_entry_and_respond(request): + # We always need to use the responder. 
+ if responder: + with responder: + pass + + return + # We deliberately stream the file outside the lock if responder and media_info: upload_name = name if name else media_info.upload_name @@ -739,11 +781,13 @@ class MediaRepository: file_info = FileInfo(server_name=server_name, file_id=file_id) async with self.media_storage.store_into_file(file_info) as (f, fname): + sha256writer = SHA256TransparentIOWriter(f) try: length, headers = await self.client.download_media( server_name, media_id, - output_stream=f, + # This implements all of BinaryIO as it has a passthrough + output_stream=sha256writer.wrap(), max_size=self.max_upload_size, max_timeout_ms=max_timeout_ms, download_ratelimiter=download_ratelimiter, @@ -808,6 +852,7 @@ class MediaRepository: upload_name=upload_name, media_length=length, filesystem_id=file_id, + sha256=sha256writer.hexdigest(), ) logger.info("Stored remote media in file %r", fname) @@ -828,6 +873,7 @@ class MediaRepository: last_access_ts=time_now_ms, quarantined_by=None, authenticated=authenticated, + sha256=sha256writer.hexdigest(), ) async def _federation_download_remote_file( @@ -862,11 +908,13 @@ class MediaRepository: file_info = FileInfo(server_name=server_name, file_id=file_id) async with self.media_storage.store_into_file(file_info) as (f, fname): + sha256writer = SHA256TransparentIOWriter(f) try: res = await self.client.federation_download_media( server_name, media_id, - output_stream=f, + # This implements all of BinaryIO as it has a passthrough + output_stream=sha256writer.wrap(), max_size=self.max_upload_size, max_timeout_ms=max_timeout_ms, download_ratelimiter=download_ratelimiter, @@ -937,6 +985,7 @@ class MediaRepository: upload_name=upload_name, media_length=length, filesystem_id=file_id, + sha256=sha256writer.hexdigest(), ) logger.debug("Stored remote media in file %r", fname) @@ -957,6 +1006,7 @@ class MediaRepository: last_access_ts=time_now_ms, quarantined_by=None, authenticated=authenticated, + 
sha256=sha256writer.hexdigest(), ) def _get_thumbnail_requirements( @@ -1008,7 +1058,7 @@ class MediaRepository: t_method: str, t_type: str, url_cache: bool, - ) -> Optional[str]: + ) -> Optional[Tuple[str, FileInfo]]: input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(None, media_id, url_cache=url_cache) ) @@ -1070,7 +1120,7 @@ class MediaRepository: t_len, ) - return output_path + return output_path, file_info # Could not generate thumbnail. return None diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py
index 2a106bb0eb..afd33c02a1 100644 --- a/synapse/media/media_storage.py +++ b/synapse/media/media_storage.py
@@ -19,6 +19,7 @@ # # import contextlib +import hashlib import json import logging import os @@ -49,15 +50,11 @@ from zope.interface import implementer from twisted.internet import interfaces from twisted.internet.defer import Deferred from twisted.internet.interfaces import IConsumer -from twisted.protocols.basic import FileSender from synapse.api.errors import NotFoundError -from synapse.logging.context import ( - defer_to_thread, - make_deferred_yieldable, - run_in_background, -) +from synapse.logging.context import defer_to_thread, run_in_background from synapse.logging.opentracing import start_active_span, trace, trace_with_opname +from synapse.media._base import ThreadedFileSender from synapse.util import Clock from synapse.util.file_consumer import BackgroundFileConsumer @@ -74,6 +71,88 @@ logger = logging.getLogger(__name__) CRLF = b"\r\n" +class SHA256TransparentIOWriter: + """Will generate a SHA256 hash from a source stream transparently. + + Args: + source: Source stream. + """ + + def __init__(self, source: BinaryIO): + self._hash = hashlib.sha256() + self._source = source + + def write(self, buffer: Union[bytes, bytearray]) -> int: + """Wrapper for source.write() + + Args: + buffer + + Returns: + the value of source.write() + """ + res = self._source.write(buffer) + self._hash.update(buffer) + return res + + def hexdigest(self) -> str: + """The digest of the written or read value. + + Returns: + The digest in hex formaat. + """ + return self._hash.hexdigest() + + def wrap(self) -> BinaryIO: + # This class implements a subset the IO interface and passes through everything else via __getattr__ + return cast(BinaryIO, self) + + # Passthrough any other calls + def __getattr__(self, attr_name: str) -> Any: + return getattr(self._source, attr_name) + + +class SHA256TransparentIOReader: + """Will generate a SHA256 hash from a source stream transparently. + + Args: + source: Source IO stream. 
+ """ + + def __init__(self, source: IO): + self._hash = hashlib.sha256() + self._source = source + + def read(self, n: int = -1) -> bytes: + """Wrapper for source.read() + + Args: + n + + Returns: + the value of source.read() + """ + bytes = self._source.read(n) + self._hash.update(bytes) + return bytes + + def hexdigest(self) -> str: + """The digest of the written or read value. + + Returns: + The digest in hex formaat. + """ + return self._hash.hexdigest() + + def wrap(self) -> IO: + # This class implements a subset the IO interface and passes through everything else via __getattr__ + return cast(IO, self) + + # Passthrough any other calls + def __getattr__(self, attr_name: str) -> Any: + return getattr(self._source, attr_name) + + class MediaStorage: """Responsible for storing/fetching files from local sources. @@ -111,7 +190,6 @@ class MediaStorage: Returns: the file path written to in the primary media store """ - async with self.store_into_file(file_info) as (f, fname): # Write to the main media repository await self.write_to_file(source, f) @@ -213,7 +291,7 @@ class MediaStorage: local_path = os.path.join(self.local_media_directory, path) if os.path.exists(local_path): logger.debug("responding with local file %s", local_path) - return FileResponder(open(local_path, "rb")) + return FileResponder(self.hs, open(local_path, "rb")) logger.debug("local file %s did not exist", local_path) for provider in self.storage_providers: @@ -336,13 +414,12 @@ class FileResponder(Responder): is closed when finished streaming. 
""" - def __init__(self, open_file: IO): + def __init__(self, hs: "HomeServer", open_file: BinaryIO): + self.hs = hs self.open_file = open_file def write_to_consumer(self, consumer: IConsumer) -> Deferred: - return make_deferred_yieldable( - FileSender().beginFileTransfer(self.open_file, consumer) - ) + return ThreadedFileSender(self.hs).beginFileTransfer(self.open_file, consumer) def __exit__( self, @@ -549,7 +626,7 @@ class MultipartFileConsumer: Calculate the content length of the multipart response in bytes. """ - if not self.length: + if self.length is None: return None # calculate length of json field and content-type, disposition headers json_field = json.dumps(self.json_field) diff --git a/synapse/media/storage_provider.py b/synapse/media/storage_provider.py
index 06e5d27a53..300952025a 100644 --- a/synapse/media/storage_provider.py +++ b/synapse/media/storage_provider.py
@@ -145,6 +145,7 @@ class FileStorageProviderBackend(StorageProvider): def __init__(self, hs: "HomeServer", config: str): self.hs = hs + self.reactor = hs.get_reactor() self.cache_directory = hs.config.media.media_store_path self.base_directory = config @@ -165,7 +166,7 @@ class FileStorageProviderBackend(StorageProvider): shutil_copyfile: Callable[[str, str], str] = shutil.copyfile with start_active_span("shutil_copyfile"): await defer_to_thread( - self.hs.get_reactor(), + self.reactor, shutil_copyfile, primary_fname, backup_fname, @@ -177,7 +178,7 @@ class FileStorageProviderBackend(StorageProvider): backup_fname = os.path.join(self.base_directory, path) if os.path.isfile(backup_fname): - return FileResponder(open(backup_fname, "rb")) + return FileResponder(self.hs, open(backup_fname, "rb")) return None diff --git a/synapse/media/thumbnailer.py b/synapse/media/thumbnailer.py
index ef6aa8ccf5..5d9afda322 100644 --- a/synapse/media/thumbnailer.py +++ b/synapse/media/thumbnailer.py
@@ -34,6 +34,7 @@ from synapse.logging.opentracing import trace from synapse.media._base import ( FileInfo, ThumbnailInfo, + check_for_cached_entry_and_respond, respond_404, respond_with_file, respond_with_multipart_responder, @@ -67,6 +68,11 @@ class ThumbnailError(Exception): class Thumbnailer: FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"} + # Which image formats we allow Pillow to open. + # This should intentionally be kept restrictive, because the decoder of any + # format in this list becomes part of our trusted computing base. + PILLOW_FORMATS = ("jpeg", "png", "webp", "gif") + @staticmethod def set_limits(max_image_pixels: int) -> None: Image.MAX_IMAGE_PIXELS = max_image_pixels @@ -76,7 +82,7 @@ class Thumbnailer: self._closed = False try: - self.image = Image.open(input_path) + self.image = Image.open(input_path, formats=self.PILLOW_FORMATS) except OSError as e: # If an error occurs opening the image, a thumbnail won't be able to # be generated. @@ -206,7 +212,7 @@ class Thumbnailer: def _encode_image(self, output_image: Image.Image, output_type: str) -> BytesIO: output_bytes_io = BytesIO() fmt = self.FORMATS[output_type] - if fmt == "JPEG": + if fmt == "JPEG" or fmt == "PNG" and output_image.mode == "CMYK": output_image = output_image.convert("RGB") output_image.save(output_bytes_io, fmt, quality=80) return output_bytes_io @@ -259,6 +265,7 @@ class ThumbnailProvider: media_storage: MediaStorage, ): self.hs = hs + self.reactor = hs.get_reactor() self.media_repo = media_repo self.media_storage = media_storage self.store = hs.get_datastores().main @@ -288,6 +295,11 @@ class ThumbnailProvider: if media_info.authenticated: raise NotFoundError() + # Once we've checked auth we can return early if the media is cached on + # the client + if check_for_cached_entry_and_respond(request): + return + thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) await self._select_and_respond_with_thumbnail( request, @@ -328,6 +340,11 @@ class 
ThumbnailProvider: if media_info.authenticated: raise NotFoundError() + # Once we've checked auth we can return early if the media is cached on + # the client + if check_for_cached_entry_and_respond(request): + return + thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) for info in thumbnail_infos: t_w = info.width == desired_width @@ -347,7 +364,12 @@ class ThumbnailProvider: if responder: if for_federation: await respond_with_multipart_responder( - self.hs.get_clock(), request, responder, media_info + self.hs.get_clock(), + request, + responder, + info.type, + info.length, + None, ) return else: @@ -359,7 +381,7 @@ class ThumbnailProvider: logger.debug("We don't have a thumbnail of that size. Generating") # Okay, so we generate one. - file_path = await self.media_repo.generate_local_exact_thumbnail( + thumbnail_result = await self.media_repo.generate_local_exact_thumbnail( media_id, desired_width, desired_height, @@ -368,16 +390,21 @@ class ThumbnailProvider: url_cache=bool(media_info.url_cache), ) - if file_path: + if thumbnail_result: + file_path, file_info = thumbnail_result + assert file_info.thumbnail is not None + if for_federation: await respond_with_multipart_responder( self.hs.get_clock(), request, - FileResponder(open(file_path, "rb")), - media_info, + FileResponder(self.hs, open(file_path, "rb")), + file_info.thumbnail.type, + file_info.thumbnail.length, + None, ) else: - await respond_with_file(request, desired_type, file_path) + await respond_with_file(self.hs, request, desired_type, file_path) else: logger.warning("Failed to generate thumbnail") raise SynapseError(400, "Failed to generate thumbnail.") @@ -415,6 +442,10 @@ class ThumbnailProvider: respond_404(request) return + # Check if the media is cached on the client, if so return 304. 
+ if check_for_cached_entry_and_respond(request): + return + thumbnail_infos = await self.store.get_remote_media_thumbnails( server_name, media_id ) @@ -455,7 +486,7 @@ class ThumbnailProvider: ) if file_path: - await respond_with_file(request, desired_type, file_path) + await respond_with_file(self.hs, request, desired_type, file_path) else: logger.warning("Failed to generate thumbnail") raise SynapseError(400, "Failed to generate thumbnail.") @@ -494,6 +525,10 @@ class ThumbnailProvider: if media_info.authenticated: raise NotFoundError() + # Check if the media is cached on the client, if so return 304. + if check_for_cached_entry_and_respond(request): + return + thumbnail_infos = await self.store.get_remote_media_thumbnails( server_name, media_id ) @@ -579,7 +614,12 @@ class ThumbnailProvider: if for_federation: assert media_info is not None await respond_with_multipart_responder( - self.hs.get_clock(), request, responder, media_info + self.hs.get_clock(), + request, + responder, + file_info.thumbnail.type, + file_info.thumbnail.length, + None, ) return else: @@ -633,7 +673,12 @@ class ThumbnailProvider: if for_federation: assert media_info is not None await respond_with_multipart_responder( - self.hs.get_clock(), request, responder, media_info + self.hs.get_clock(), + request, + responder, + file_info.thumbnail.type, + file_info.thumbnail.length, + None, ) else: await respond_with_responder( diff --git a/synapse/media/url_previewer.py b/synapse/media/url_previewer.py
index 2e65a04789..8ef2b3f0c0 100644 --- a/synapse/media/url_previewer.py +++ b/synapse/media/url_previewer.py
@@ -41,7 +41,7 @@ from synapse.api.errors import Codes, SynapseError from synapse.http.client import SimpleHttpClient from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.media._base import FileInfo, get_filename_from_headers -from synapse.media.media_storage import MediaStorage +from synapse.media.media_storage import MediaStorage, SHA256TransparentIOWriter from synapse.media.oembed import OEmbedProvider from synapse.media.preview_html import decode_body, parse_html_to_open_graph from synapse.metrics.background_process_metrics import run_as_background_process @@ -593,17 +593,26 @@ class UrlPreviewer: file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True) async with self.media_storage.store_into_file(file_info) as (f, fname): + sha256writer = SHA256TransparentIOWriter(f) if url.startswith("data:"): if not allow_data_urls: raise SynapseError( 500, "Previewing of data: URLs is forbidden", Codes.UNKNOWN ) - download_result = await self._parse_data_url(url, f) + download_result = await self._parse_data_url(url, sha256writer.wrap()) else: - download_result = await self._download_url(url, f) + download_result = await self._download_url(url, sha256writer.wrap()) try: + sha256 = sha256writer.hexdigest() + should_quarantine = await self.store.get_is_hash_quarantined(sha256) + + if should_quarantine: + logger.warn( + "Media has been automatically quarantined as it matched existing quarantined media" + ) + time_now_ms = self.clock.time_msec() await self.store.store_local_media( @@ -614,6 +623,8 @@ class UrlPreviewer: media_length=download_result.length, user_id=user, url_cache=url, + sha256=sha256, + quarantined_by="system" if should_quarantine else None, ) except Exception as e: diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 3051b623d0..9ce83da4ba 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py
@@ -427,16 +427,6 @@ build_info.labels( " ".join([platform.system(), platform.release()]), ).set(1) -# 3PID send info -threepid_send_requests = Histogram( - "synapse_threepid_send_requests_with_tries", - documentation="Number of requests for a 3pid token by try count. Note if" - " there is a request with try count of 4, then there would have been one" - " each for 1, 2 and 3", - buckets=(1, 2, 3, 4, 5, 10), - labelnames=("type", "reason"), -) - threadpool_total_threads = Gauge( "synapse_threadpool_total_threads", "Total number of threads currently in the threadpool", diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 19c92b02a0..49d0ff9fc1 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py
@@ -293,7 +293,7 @@ def wrap_as_background_process( """ def wrap_as_background_process_inner( - func: Callable[P, Awaitable[Optional[R]]] + func: Callable[P, Awaitable[Optional[R]]], ) -> Callable[P, "defer.Deferred[Optional[R]]"]: @wraps(func) def wrap_as_background_process_inner_2( diff --git a/synapse/metrics/jemalloc.py b/synapse/metrics/jemalloc.py
index bd25985686..321ff58083 100644 --- a/synapse/metrics/jemalloc.py +++ b/synapse/metrics/jemalloc.py
@@ -23,11 +23,10 @@ import ctypes import logging import os import re -from typing import Iterable, Optional, overload +from typing import Iterable, Literal, Optional, overload import attr from prometheus_client import REGISTRY, Metric -from typing_extensions import Literal from synapse.metrics import GaugeMetricFamily from synapse.metrics._types import Collector diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index f6bfd93d3c..e22d6f3ab7 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py
@@ -18,7 +18,6 @@ # [This file includes modifications made by New Vector Limited] # # -import email.utils import logging from typing import ( TYPE_CHECKING, @@ -45,6 +44,7 @@ from twisted.internet.interfaces import IDelayedCall from twisted.web.resource import Resource from synapse.api import errors +from synapse.api.constants import ProfileFields from synapse.api.errors import SynapseError from synapse.api.presence import UserPresenceState from synapse.config import ConfigError @@ -89,6 +89,14 @@ from synapse.module_api.callbacks.account_validity_callbacks import ( ON_USER_LOGIN_CALLBACK, ON_USER_REGISTRATION_CALLBACK, ) +from synapse.module_api.callbacks.media_repository_callbacks import ( + GET_MEDIA_CONFIG_FOR_USER_CALLBACK, + IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK, +) +from synapse.module_api.callbacks.ratelimit_callbacks import ( + GET_RATELIMIT_OVERRIDE_FOR_USER_CALLBACK, + RatelimitOverride, +) from synapse.module_api.callbacks.spamchecker_callbacks import ( CHECK_EVENT_FOR_SPAM_CALLBACK, CHECK_LOGIN_FOR_SPAM_CALLBACK, @@ -101,21 +109,17 @@ from synapse.module_api.callbacks.spamchecker_callbacks import ( USER_MAY_INVITE_CALLBACK, USER_MAY_JOIN_ROOM_CALLBACK, USER_MAY_PUBLISH_ROOM_CALLBACK, - USER_MAY_SEND_3PID_INVITE_CALLBACK, + USER_MAY_SEND_STATE_EVENT_CALLBACK, SpamCheckerModuleApiCallbacks, ) from synapse.module_api.callbacks.third_party_event_rules_callbacks import ( CHECK_CAN_DEACTIVATE_USER_CALLBACK, CHECK_CAN_SHUTDOWN_ROOM_CALLBACK, CHECK_EVENT_ALLOWED_CALLBACK, - CHECK_THREEPID_CAN_BE_INVITED_CALLBACK, CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK, - ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK, ON_CREATE_ROOM_CALLBACK, ON_NEW_EVENT_CALLBACK, ON_PROFILE_UPDATE_CALLBACK, - ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK, - ON_THREEPID_BIND_CALLBACK, ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK, ) from synapse.push.httppusher import HttpPusher @@ -188,6 +192,7 @@ __all__ = [ "ProfileInfo", "RoomAlias", "UserProfile", + "RatelimitOverride", ] 
logger = logging.getLogger(__name__) @@ -260,7 +265,6 @@ class ModuleApi: self._state = hs.get_state_handler() self._clock: Clock = hs.get_clock() self._registration_handler = hs.get_registration_handler() - self._send_email_handler = hs.get_send_email_handler() self._push_rules_handler = hs.get_push_rules_handler() self._pusherpool = hs.get_pusherpool() self._device_handler = hs.get_device_handler() @@ -269,20 +273,6 @@ class ModuleApi: self.msc3861_oauth_delegation_enabled = hs.config.experimental.msc3861.enabled self._event_serializer = hs.get_event_client_serializer() - try: - app_name = self._hs.config.email.email_app_name - - self._from_string = self._hs.config.email.email_notif_from % { - "app": app_name - } - except (KeyError, TypeError): - # If substitution failed (which can happen if the string contains - # placeholders other than just "app", or if the type of the placeholder is - # not a string), fall back to the bare strings. - self._from_string = self._hs.config.email.email_notif_from - - self._raw_from = email.utils.parseaddr(self._from_string)[1] - # We expose these as properties below in order to attach a helpful docstring. 
self._http_client: SimpleHttpClient = hs.get_simple_http_client() self._public_room_list_manager = PublicRoomListManager(hs) @@ -304,12 +294,12 @@ class ModuleApi: ] = None, user_may_join_room: Optional[USER_MAY_JOIN_ROOM_CALLBACK] = None, user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None, - user_may_send_3pid_invite: Optional[USER_MAY_SEND_3PID_INVITE_CALLBACK] = None, user_may_create_room: Optional[USER_MAY_CREATE_ROOM_CALLBACK] = None, user_may_create_room_alias: Optional[ USER_MAY_CREATE_ROOM_ALIAS_CALLBACK ] = None, user_may_publish_room: Optional[USER_MAY_PUBLISH_ROOM_CALLBACK] = None, + user_may_send_state_event: Optional[USER_MAY_SEND_STATE_EVENT_CALLBACK] = None, check_username_for_spam: Optional[CHECK_USERNAME_FOR_SPAM_CALLBACK] = None, check_registration_for_spam: Optional[ CHECK_REGISTRATION_FOR_SPAM_CALLBACK @@ -326,7 +316,6 @@ class ModuleApi: should_drop_federated_event=should_drop_federated_event, user_may_join_room=user_may_join_room, user_may_invite=user_may_invite, - user_may_send_3pid_invite=user_may_send_3pid_invite, user_may_create_room=user_may_create_room, user_may_create_room_alias=user_may_create_room_alias, user_may_publish_room=user_may_publish_room, @@ -334,6 +323,7 @@ class ModuleApi: check_registration_for_spam=check_registration_for_spam, check_media_file_for_spam=check_media_file_for_spam, check_login_for_spam=check_login_for_spam, + user_may_send_state_event=user_may_send_state_event, ) def register_account_validity_callbacks( @@ -359,14 +349,41 @@ class ModuleApi: on_legacy_admin_request=on_legacy_admin_request, ) + def register_ratelimit_callbacks( + self, + *, + get_ratelimit_override_for_user: Optional[ + GET_RATELIMIT_OVERRIDE_FOR_USER_CALLBACK + ] = None, + ) -> None: + """Registers callbacks for ratelimit capabilities. + Added in Synapse v1.132.0. 
+ """ + return self._callbacks.ratelimit.register_callbacks( + get_ratelimit_override_for_user=get_ratelimit_override_for_user, + ) + + def register_media_repository_callbacks( + self, + *, + get_media_config_for_user: Optional[GET_MEDIA_CONFIG_FOR_USER_CALLBACK] = None, + is_user_allowed_to_upload_media_of_size: Optional[ + IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK + ] = None, + ) -> None: + """Registers callbacks for media repository capabilities. + Added in Synapse v1.132.0. + """ + return self._callbacks.media_repository.register_callbacks( + get_media_config_for_user=get_media_config_for_user, + is_user_allowed_to_upload_media_of_size=is_user_allowed_to_upload_media_of_size, + ) + def register_third_party_rules_callbacks( self, *, check_event_allowed: Optional[CHECK_EVENT_ALLOWED_CALLBACK] = None, on_create_room: Optional[ON_CREATE_ROOM_CALLBACK] = None, - check_threepid_can_be_invited: Optional[ - CHECK_THREEPID_CAN_BE_INVITED_CALLBACK - ] = None, check_visibility_can_be_modified: Optional[ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK ] = None, @@ -377,13 +394,6 @@ class ModuleApi: on_user_deactivation_status_changed: Optional[ ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK ] = None, - on_threepid_bind: Optional[ON_THREEPID_BIND_CALLBACK] = None, - on_add_user_third_party_identifier: Optional[ - ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK - ] = None, - on_remove_user_third_party_identifier: Optional[ - ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK - ] = None, ) -> None: """Registers callbacks for third party event rules capabilities. 
@@ -392,16 +402,12 @@ class ModuleApi: return self._callbacks.third_party_event_rules.register_third_party_rules_callbacks( check_event_allowed=check_event_allowed, on_create_room=on_create_room, - check_threepid_can_be_invited=check_threepid_can_be_invited, check_visibility_can_be_modified=check_visibility_can_be_modified, on_new_event=on_new_event, check_can_shutdown_room=check_can_shutdown_room, check_can_deactivate_user=check_can_deactivate_user, on_profile_update=on_profile_update, on_user_deactivation_status_changed=on_user_deactivation_status_changed, - on_threepid_bind=on_threepid_bind, - on_add_user_third_party_identifier=on_add_user_third_party_identifier, - on_remove_user_third_party_identifier=on_remove_user_third_party_identifier, ) def register_presence_router_callbacks( @@ -561,14 +567,6 @@ class ModuleApi: return self._hs.config.server.public_baseurl @property - def email_app_name(self) -> str: - """The application name configured in the homeserver's configuration. - - Added in Synapse v1.39.0. - """ - return self._hs.config.email.email_app_name - - @property def server_name(self) -> str: """The server name for the local homeserver. @@ -695,23 +693,6 @@ class ModuleApi: user_id = UserID.from_string(f"@{localpart}:{server_name}") return await self._store.get_profileinfo(user_id) - async def get_threepids_for_user(self, user_id: str) -> List[Dict[str, str]]: - """Look up the threepids (email addresses and phone numbers) associated with the - given Matrix user ID. - - Added in Synapse v1.39.0. - - Args: - user_id: The Matrix user ID to look up threepids for. - - Returns: - A list of threepids, each threepid being represented by a dictionary - containing a "medium" key which value is "email" for email addresses and - "msisdn" for phone numbers, and an "address" key which value is the - threepid's address. 
- """ - return [attr.asdict(t) for t in await self._store.user_get_threepids(user_id)] - def check_user_exists(self, user_id: str) -> "defer.Deferred[Optional[str]]": """Check if user exists. @@ -893,9 +874,9 @@ class ModuleApi: Raises: synapse.api.errors.AuthError: the access token is invalid """ - assert isinstance( - self._device_handler, DeviceHandler - ), "invalidate_access_token can only be called on the main process" + assert isinstance(self._device_handler, DeviceHandler), ( + "invalidate_access_token can only be called on the main process" + ) # see if the access token corresponds to a device user_info = yield defer.ensureDeferred( @@ -1086,7 +1067,10 @@ class ModuleApi: content = {} # Set the profile if not already done by the module. - if "avatar_url" not in content or "displayname" not in content: + if ( + ProfileFields.AVATAR_URL not in content + or ProfileFields.DISPLAYNAME not in content + ): try: # Try to fetch the user's profile. profile = await self._hs.get_profile_handler().get_profile( @@ -1095,8 +1079,8 @@ class ModuleApi: except SynapseError as e: # If the profile couldn't be found, use default values. profile = { - "displayname": target_user_id.localpart, - "avatar_url": None, + ProfileFields.DISPLAYNAME: target_user_id.localpart, + ProfileFields.AVATAR_URL: None, } if e.code != 404: @@ -1109,11 +1093,9 @@ class ModuleApi: ) # Set the profile where it needs to be set. 
- if "avatar_url" not in content: - content["avatar_url"] = profile["avatar_url"] - - if "displayname" not in content: - content["displayname"] = profile["displayname"] + for field_name in [ProfileFields.AVATAR_URL, ProfileFields.DISPLAYNAME]: + if field_name not in content and field_name in profile: + content[field_name] = profile[field_name] event_id, _ = await self._hs.get_room_member_handler().update_membership( requester=requester, @@ -1398,31 +1380,6 @@ class ModuleApi: status[p.device_id] = sent return status - async def send_mail( - self, - recipient: str, - subject: str, - html: str, - text: str, - ) -> None: - """Send an email on behalf of the homeserver. - - Added in Synapse v1.39.0. - - Args: - recipient: The email address for the recipient. - subject: The email's subject. - html: The email's HTML content. - text: The email's text content. - """ - await self._send_email_handler.send_email( - email_address=recipient, - subject=subject, - app_name=self.email_app_name, - html=html, - text=text, - ) - def read_templates( self, filenames: List[str], @@ -1584,30 +1541,6 @@ class ModuleApi: """ await self._registration_handler.check_username(username) - async def store_remote_3pid_association( - self, user_id: str, medium: str, address: str, id_server: str - ) -> None: - """Stores an existing association between a user ID and a third-party identifier. - - The association must already exist on the remote identity server. - - Added in Synapse v1.56.0. - - Args: - user_id: The user ID that's been associated with the 3PID. - medium: The medium of the 3PID (current supported values are "msisdn" and - "email"). - address: The address of the 3PID. - id_server: The identity server the 3PID association has been registered on. - This should only be the domain (or IP address, optionally with the port - number) for the identity server. 
This will be used to reach out to the - identity server using HTTPS (unless specified otherwise by Synapse's - configuration) when attempting to unbind the third-party identifier. - - - """ - await self._store.add_user_bound_threepid(user_id, medium, address, id_server) - def check_push_rule_actions( self, actions: List[Union[str, Dict[str, str]]] ) -> None: @@ -1844,6 +1777,10 @@ class ModuleApi: deactivation=deactivation, ) + def get_current_time_msec(self) -> int: + """Returns the current server time in milliseconds.""" + return self._clock.time_msec() + class PublicRoomListManager: """Contains methods for adding to, removing from and querying whether a room diff --git a/synapse/module_api/callbacks/__init__.py b/synapse/module_api/callbacks/__init__.py
index c20d9543fb..16ef7a4b47 100644 --- a/synapse/module_api/callbacks/__init__.py +++ b/synapse/module_api/callbacks/__init__.py
@@ -27,6 +27,12 @@ if TYPE_CHECKING: from synapse.module_api.callbacks.account_validity_callbacks import ( AccountValidityModuleApiCallbacks, ) +from synapse.module_api.callbacks.media_repository_callbacks import ( + MediaRepositoryModuleApiCallbacks, +) +from synapse.module_api.callbacks.ratelimit_callbacks import ( + RatelimitModuleApiCallbacks, +) from synapse.module_api.callbacks.spamchecker_callbacks import ( SpamCheckerModuleApiCallbacks, ) @@ -38,5 +44,7 @@ from synapse.module_api.callbacks.third_party_event_rules_callbacks import ( class ModuleApiCallbacks: def __init__(self, hs: "HomeServer") -> None: self.account_validity = AccountValidityModuleApiCallbacks() + self.media_repository = MediaRepositoryModuleApiCallbacks(hs) + self.ratelimit = RatelimitModuleApiCallbacks(hs) self.spam_checker = SpamCheckerModuleApiCallbacks(hs) self.third_party_event_rules = ThirdPartyEventRulesModuleApiCallbacks(hs) diff --git a/synapse/module_api/callbacks/media_repository_callbacks.py b/synapse/module_api/callbacks/media_repository_callbacks.py new file mode 100644
index 0000000000..6fa80a8eab --- /dev/null +++ b/synapse/module_api/callbacks/media_repository_callbacks.py
@@ -0,0 +1,76 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# + +import logging +from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional + +from synapse.types import JsonDict +from synapse.util.async_helpers import delay_cancellation +from synapse.util.metrics import Measure + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + +GET_MEDIA_CONFIG_FOR_USER_CALLBACK = Callable[[str], Awaitable[Optional[JsonDict]]] + +IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK = Callable[[str, int], Awaitable[bool]] + + +class MediaRepositoryModuleApiCallbacks: + def __init__(self, hs: "HomeServer") -> None: + self.clock = hs.get_clock() + self._get_media_config_for_user_callbacks: List[ + GET_MEDIA_CONFIG_FOR_USER_CALLBACK + ] = [] + self._is_user_allowed_to_upload_media_of_size_callbacks: List[ + IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK + ] = [] + + def register_callbacks( + self, + get_media_config_for_user: Optional[GET_MEDIA_CONFIG_FOR_USER_CALLBACK] = None, + is_user_allowed_to_upload_media_of_size: Optional[ + IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK + ] = None, + ) -> None: + """Register callbacks from module for each hook.""" + if get_media_config_for_user is not None: + self._get_media_config_for_user_callbacks.append(get_media_config_for_user) + + if is_user_allowed_to_upload_media_of_size is not None: + self._is_user_allowed_to_upload_media_of_size_callbacks.append( + is_user_allowed_to_upload_media_of_size + ) + + async def 
get_media_config_for_user(self, user_id: str) -> Optional[JsonDict]: + for callback in self._get_media_config_for_user_callbacks: + with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): + res: Optional[JsonDict] = await delay_cancellation(callback(user_id)) + if res: + return res + + return None + + async def is_user_allowed_to_upload_media_of_size( + self, user_id: str, size: int + ) -> bool: + for callback in self._is_user_allowed_to_upload_media_of_size_callbacks: + with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): + res: bool = await delay_cancellation(callback(user_id, size)) + if not res: + return res + + return True diff --git a/synapse/module_api/callbacks/ratelimit_callbacks.py b/synapse/module_api/callbacks/ratelimit_callbacks.py new file mode 100644
index 0000000000..64f9cc81e8 --- /dev/null +++ b/synapse/module_api/callbacks/ratelimit_callbacks.py
@@ -0,0 +1,74 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# + +import logging +from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional + +import attr + +from synapse.util.async_helpers import delay_cancellation +from synapse.util.metrics import Measure + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +@attr.s(auto_attribs=True) +class RatelimitOverride: + """Represents a ratelimit being overridden.""" + + per_second: float + """The number of actions that can be performed in a second. 
`0.0` means that ratelimiting is disabled.""" + burst_count: int + """How many actions that can be performed before being limited.""" + + +GET_RATELIMIT_OVERRIDE_FOR_USER_CALLBACK = Callable[ + [str, str], Awaitable[Optional[RatelimitOverride]] +] + + +class RatelimitModuleApiCallbacks: + def __init__(self, hs: "HomeServer") -> None: + self.clock = hs.get_clock() + self._get_ratelimit_override_for_user_callbacks: List[ + GET_RATELIMIT_OVERRIDE_FOR_USER_CALLBACK + ] = [] + + def register_callbacks( + self, + get_ratelimit_override_for_user: Optional[ + GET_RATELIMIT_OVERRIDE_FOR_USER_CALLBACK + ] = None, + ) -> None: + """Register callbacks from module for each hook.""" + if get_ratelimit_override_for_user is not None: + self._get_ratelimit_override_for_user_callbacks.append( + get_ratelimit_override_for_user + ) + + async def get_ratelimit_override_for_user( + self, user_id: str, limiter_name: str + ) -> Optional[RatelimitOverride]: + for callback in self._get_ratelimit_override_for_user_callbacks: + with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): + res: Optional[RatelimitOverride] = await delay_cancellation( + callback(user_id, limiter_name) + ) + if res: + return res + + return None diff --git a/synapse/module_api/callbacks/spamchecker_callbacks.py b/synapse/module_api/callbacks/spamchecker_callbacks.py
index 17079ff781..30cab9eb7e 100644 --- a/synapse/module_api/callbacks/spamchecker_callbacks.py +++ b/synapse/module_api/callbacks/spamchecker_callbacks.py
@@ -19,8 +19,10 @@ # # +import functools import inspect import logging +from copy import deepcopy from typing import ( TYPE_CHECKING, Any, @@ -28,14 +30,13 @@ from typing import ( Callable, Collection, List, + Literal, Optional, Tuple, Union, + cast, ) -# `Literal` appears with Python 3.8. -from typing_extensions import Literal - import synapse from synapse.api.errors import Codes from synapse.logging.opentracing import trace @@ -104,24 +105,28 @@ USER_MAY_INVITE_CALLBACK = Callable[ ] ], ] -USER_MAY_SEND_3PID_INVITE_CALLBACK = Callable[ - [str, str, str, str], - Awaitable[ - Union[ - Literal["NOT_SPAM"], - Codes, - # Highly experimental, not officially part of the spamchecker API, may - # disappear without warning depending on the results of ongoing - # experiments. - # Use this to return additional information as part of an error. - Tuple[Codes, JsonDict], - # Deprecated - bool, - ] +USER_MAY_CREATE_ROOM_CALLBACK_RETURN_VALUE = Union[ + Literal["NOT_SPAM"], + Codes, + # Highly experimental, not officially part of the spamchecker API, may + # disappear without warning depending on the results of ongoing + # experiments. + # Use this to return additional information as part of an error. 
+ Tuple[Codes, JsonDict], + # Deprecated + bool, +] +USER_MAY_CREATE_ROOM_CALLBACK = Union[ + Callable[ + [str, JsonDict], + Awaitable[USER_MAY_CREATE_ROOM_CALLBACK_RETURN_VALUE], + ], + Callable[ # Single argument variant for backwards compatibility + [str], Awaitable[USER_MAY_CREATE_ROOM_CALLBACK_RETURN_VALUE] ], ] -USER_MAY_CREATE_ROOM_CALLBACK = Callable[ - [str], +USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[ + [str, RoomAlias], Awaitable[ Union[ Literal["NOT_SPAM"], @@ -136,8 +141,8 @@ USER_MAY_CREATE_ROOM_CALLBACK = Callable[ ] ], ] -USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[ - [str, RoomAlias], +USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[ + [str, str], Awaitable[ Union[ Literal["NOT_SPAM"], @@ -152,8 +157,8 @@ USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[ ] ], ] -USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[ - [str, str], +USER_MAY_SEND_STATE_EVENT_CALLBACK = Callable[ + [str, str, str, str, JsonDict], Awaitable[ Union[ Literal["NOT_SPAM"], @@ -163,12 +168,13 @@ USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[ # experiments. # Use this to return additional information as part of an error. Tuple[Codes, JsonDict], - # Deprecated - bool, ] ], ] -CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[UserProfile], Awaitable[bool]] +CHECK_USERNAME_FOR_SPAM_CALLBACK = Union[ + Callable[[UserProfile], Awaitable[bool]], + Callable[[UserProfile, str], Awaitable[bool]], +] LEGACY_CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[ [ Optional[dict], @@ -293,6 +299,7 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None: "Bad signature for callback check_registration_for_spam", ) + @functools.wraps(wrapped_func) def run(*args: Any, **kwargs: Any) -> Awaitable: # Assertion required because mypy can't prove we won't change `f` # back to `None`. 
See @@ -324,10 +331,10 @@ class SpamCheckerModuleApiCallbacks: ] = [] self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = [] self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = [] - self._user_may_send_3pid_invite_callbacks: List[ - USER_MAY_SEND_3PID_INVITE_CALLBACK - ] = [] self._user_may_create_room_callbacks: List[USER_MAY_CREATE_ROOM_CALLBACK] = [] + self._user_may_send_state_event_callbacks: List[ + USER_MAY_SEND_STATE_EVENT_CALLBACK + ] = [] self._user_may_create_room_alias_callbacks: List[ USER_MAY_CREATE_ROOM_ALIAS_CALLBACK ] = [] @@ -351,7 +358,6 @@ class SpamCheckerModuleApiCallbacks: ] = None, user_may_join_room: Optional[USER_MAY_JOIN_ROOM_CALLBACK] = None, user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None, - user_may_send_3pid_invite: Optional[USER_MAY_SEND_3PID_INVITE_CALLBACK] = None, user_may_create_room: Optional[USER_MAY_CREATE_ROOM_CALLBACK] = None, user_may_create_room_alias: Optional[ USER_MAY_CREATE_ROOM_ALIAS_CALLBACK @@ -363,6 +369,7 @@ class SpamCheckerModuleApiCallbacks: ] = None, check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None, check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None, + user_may_send_state_event: Optional[USER_MAY_SEND_STATE_EVENT_CALLBACK] = None, ) -> None: """Register callbacks from module for each hook.""" if check_event_for_spam is not None: @@ -379,14 +386,14 @@ class SpamCheckerModuleApiCallbacks: if user_may_invite is not None: self._user_may_invite_callbacks.append(user_may_invite) - if user_may_send_3pid_invite is not None: - self._user_may_send_3pid_invite_callbacks.append( - user_may_send_3pid_invite, - ) - if user_may_create_room is not None: self._user_may_create_room_callbacks.append(user_may_create_room) + if user_may_send_state_event is not None: + self._user_may_send_state_event_callbacks.append( + user_may_send_state_event, + ) + if user_may_create_room_alias is not None: 
self._user_may_create_room_alias_callbacks.append( user_may_create_room_alias, @@ -573,29 +580,42 @@ class SpamCheckerModuleApiCallbacks: # No spam-checker has rejected the request, let it pass. return self.NOT_SPAM - async def user_may_send_3pid_invite( - self, inviter_userid: str, medium: str, address: str, room_id: str + async def user_may_create_room( + self, userid: str, room_config: JsonDict ) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]: - """Checks if a given user may invite a given threepid into the room - - Note that if the threepid is already associated with a Matrix user ID, Synapse - will call user_may_invite with said user ID instead. + """Checks if a given user may create a room Args: - inviter_userid: The user ID of the sender of the invitation - medium: The 3PID's medium (e.g. "email") - address: The 3PID's address (e.g. "alice@example.com") - room_id: The room ID - - Returns: - NOT_SPAM if the operation is permitted, Codes otherwise. + userid: The ID of the user attempting to create a room + room_config: The room creation configuration which is the body of the /createRoom request """ - for callback in self._user_may_send_3pid_invite_callbacks: + for callback in self._user_may_create_room_callbacks: with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): - res = await delay_cancellation( - callback(inviter_userid, medium, address, room_id) - ) - # Normalize return values to `Codes` or `"NOT_SPAM"`. + checker_args = inspect.signature(callback) + # Also ensure backwards compatibility with spam checker callbacks + # that don't expect the room_config argument. + if len(checker_args.parameters) == 2: + callback_with_requester_id = cast( + Callable[ + [str, JsonDict], + Awaitable[USER_MAY_CREATE_ROOM_CALLBACK_RETURN_VALUE], + ], + callback, + ) + # We make a copy of the config to ensure the spam checker cannot modify it. 
+ res = await delay_cancellation( + callback_with_requester_id(userid, deepcopy(room_config)) + ) + else: + callback_without_requester_id = cast( + Callable[ + [str], Awaitable[USER_MAY_CREATE_ROOM_CALLBACK_RETURN_VALUE] + ], + callback, + ) + res = await delay_cancellation( + callback_without_requester_id(userid) + ) if res is True or res is self.NOT_SPAM: continue elif res is False: @@ -611,36 +631,38 @@ class SpamCheckerModuleApiCallbacks: return res else: logger.warning( - "Module returned invalid value, rejecting 3pid invite as spam" + "Module returned invalid value, rejecting room creation as spam" ) return synapse.api.errors.Codes.FORBIDDEN, {} return self.NOT_SPAM - async def user_may_create_room( - self, userid: str + async def user_may_send_state_event( + self, + user_id: str, + room_id: str, + event_type: str, + state_key: str, + content: JsonDict, ) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]: - """Checks if a given user may create a room - + """Checks if a given user may create a room with a given visibility Args: - userid: The ID of the user attempting to create a room + user_id: The ID of the user attempting to create a room + room_id: The ID of the room that the event will be sent to + event_type: The type of the state event + state_key: The state key of the state event + content: The content of the state event """ - for callback in self._user_may_create_room_callbacks: + for callback in self._user_may_send_state_event_callbacks: with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): - res = await delay_cancellation(callback(userid)) - if res is True or res is self.NOT_SPAM: + # We make a copy of the content to ensure that the spam checker cannot modify it. 
+ res = await delay_cancellation( + callback(user_id, room_id, event_type, state_key, deepcopy(content)) + ) + if res is self.NOT_SPAM: continue - elif res is False: - return synapse.api.errors.Codes.FORBIDDEN, {} elif isinstance(res, synapse.api.errors.Codes): return res, {} - elif ( - isinstance(res, tuple) - and len(res) == 2 - and isinstance(res[0], synapse.api.errors.Codes) - and isinstance(res[1], dict) - ): - return res else: logger.warning( "Module returned invalid value, rejecting room creation as spam" @@ -716,7 +738,9 @@ class SpamCheckerModuleApiCallbacks: return self.NOT_SPAM - async def check_username_for_spam(self, user_profile: UserProfile) -> bool: + async def check_username_for_spam( + self, user_profile: UserProfile, requester_id: str + ) -> bool: """Checks if a user ID or display name are considered "spammy" by this server. If the server considers a username spammy, then it will not be included in @@ -727,15 +751,33 @@ class SpamCheckerModuleApiCallbacks: * user_id * display_name * avatar_url + requester_id: The user ID of the user making the user directory search request. Returns: True if the user is spammy. """ for callback in self._check_username_for_spam_callbacks: with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): + checker_args = inspect.signature(callback) # Make a copy of the user profile object to ensure the spam checker cannot # modify it. - res = await delay_cancellation(callback(user_profile.copy())) + # Also ensure backwards compatibility with spam checker callbacks + # that don't expect the requester_id argument. 
+ if len(checker_args.parameters) == 2: + callback_with_requester_id = cast( + Callable[[UserProfile, str], Awaitable[bool]], callback + ) + res = await delay_cancellation( + callback_with_requester_id(user_profile.copy(), requester_id) + ) + else: + callback_without_requester_id = cast( + Callable[[UserProfile], Awaitable[bool]], callback + ) + res = await delay_cancellation( + callback_without_requester_id(user_profile.copy()) + ) + if res: return True @@ -755,8 +797,8 @@ class SpamCheckerModuleApiCallbacks: username: The request user name, if any request_info: List of tuples of user agent and IP that were used during the registration process. - auth_provider_id: The SSO IdP the user used, e.g "oidc", "saml", - "cas". If any. Note this does not include users registered + auth_provider_id: The SSO IdP the user used, e.g "oidc". + If any. Note this does not include users registered via a password provider. Returns: @@ -844,8 +886,8 @@ class SpamCheckerModuleApiCallbacks: user_id: The request user ID request_info: List of tuples of user agent and IP that were used during the registration process. - auth_provider_id: The SSO IdP the user used, e.g "oidc", "saml", - "cas". If any. Note this does not include users registered + auth_provider_id: The SSO IdP the user used, e.g "oidc". + If any. Note this does not include users registered via a password provider. Returns: diff --git a/synapse/module_api/callbacks/third_party_event_rules_callbacks.py b/synapse/module_api/callbacks/third_party_event_rules_callbacks.py
index 9f7a04372d..13508cc582 100644 --- a/synapse/module_api/callbacks/third_party_event_rules_callbacks.py +++ b/synapse/module_api/callbacks/third_party_event_rules_callbacks.py
@@ -40,9 +40,6 @@ CHECK_EVENT_ALLOWED_CALLBACK = Callable[ [EventBase, StateMap[EventBase]], Awaitable[Tuple[bool, Optional[dict]]] ] ON_CREATE_ROOM_CALLBACK = Callable[[Requester, dict, bool], Awaitable] -CHECK_THREEPID_CAN_BE_INVITED_CALLBACK = Callable[ - [str, str, StateMap[EventBase]], Awaitable[bool] -] CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[ [str, StateMap[EventBase], str], Awaitable[bool] ] @@ -51,9 +48,6 @@ CHECK_CAN_SHUTDOWN_ROOM_CALLBACK = Callable[[Optional[str], str], Awaitable[bool CHECK_CAN_DEACTIVATE_USER_CALLBACK = Callable[[str, bool], Awaitable[bool]] ON_PROFILE_UPDATE_CALLBACK = Callable[[str, ProfileInfo, bool, bool], Awaitable] ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Awaitable] -ON_THREEPID_BIND_CALLBACK = Callable[[str, str, str], Awaitable] -ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK = Callable[[str, str, str], Awaitable] -ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK = Callable[[str, str, str], Awaitable] def load_legacy_third_party_event_rules(hs: "HomeServer") -> None: @@ -73,7 +67,6 @@ def load_legacy_third_party_event_rules(hs: "HomeServer") -> None: third_party_event_rules_methods = { "check_event_allowed", "on_create_room", - "check_threepid_can_be_invited", "check_visibility_can_be_modified", } @@ -161,9 +154,6 @@ class ThirdPartyEventRulesModuleApiCallbacks: self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = [] self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = [] - self._check_threepid_can_be_invited_callbacks: List[ - CHECK_THREEPID_CAN_BE_INVITED_CALLBACK - ] = [] self._check_visibility_can_be_modified_callbacks: List[ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK ] = [] @@ -178,21 +168,11 @@ class ThirdPartyEventRulesModuleApiCallbacks: self._on_user_deactivation_status_changed_callbacks: List[ ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK ] = [] - self._on_threepid_bind_callbacks: List[ON_THREEPID_BIND_CALLBACK] = [] - 
self._on_add_user_third_party_identifier_callbacks: List[ - ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK - ] = [] - self._on_remove_user_third_party_identifier_callbacks: List[ - ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK - ] = [] def register_third_party_rules_callbacks( self, check_event_allowed: Optional[CHECK_EVENT_ALLOWED_CALLBACK] = None, on_create_room: Optional[ON_CREATE_ROOM_CALLBACK] = None, - check_threepid_can_be_invited: Optional[ - CHECK_THREEPID_CAN_BE_INVITED_CALLBACK - ] = None, check_visibility_can_be_modified: Optional[ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK ] = None, @@ -202,14 +182,7 @@ class ThirdPartyEventRulesModuleApiCallbacks: on_profile_update: Optional[ON_PROFILE_UPDATE_CALLBACK] = None, on_user_deactivation_status_changed: Optional[ ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK - ] = None, - on_threepid_bind: Optional[ON_THREEPID_BIND_CALLBACK] = None, - on_add_user_third_party_identifier: Optional[ - ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK - ] = None, - on_remove_user_third_party_identifier: Optional[ - ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK - ] = None, + ] = None ) -> None: """Register callbacks from modules for each hook.""" if check_event_allowed is not None: @@ -218,11 +191,6 @@ class ThirdPartyEventRulesModuleApiCallbacks: if on_create_room is not None: self._on_create_room_callbacks.append(on_create_room) - if check_threepid_can_be_invited is not None: - self._check_threepid_can_be_invited_callbacks.append( - check_threepid_can_be_invited, - ) - if check_visibility_can_be_modified is not None: self._check_visibility_can_be_modified_callbacks.append( check_visibility_can_be_modified, @@ -236,6 +204,7 @@ class ThirdPartyEventRulesModuleApiCallbacks: if check_can_deactivate_user is not None: self._check_can_deactivate_user_callbacks.append(check_can_deactivate_user) + if on_profile_update is not None: self._on_profile_update_callbacks.append(on_profile_update) @@ -244,19 +213,6 @@ class 
ThirdPartyEventRulesModuleApiCallbacks: on_user_deactivation_status_changed, ) - if on_threepid_bind is not None: - self._on_threepid_bind_callbacks.append(on_threepid_bind) - - if on_add_user_third_party_identifier is not None: - self._on_add_user_third_party_identifier_callbacks.append( - on_add_user_third_party_identifier - ) - - if on_remove_user_third_party_identifier is not None: - self._on_remove_user_third_party_identifier_callbacks.append( - on_remove_user_third_party_identifier - ) - async def check_event_allowed( self, event: EventBase, @@ -349,39 +305,6 @@ class ThirdPartyEventRulesModuleApiCallbacks: raise e - async def check_threepid_can_be_invited( - self, medium: str, address: str, room_id: str - ) -> bool: - """Check if a provided 3PID can be invited in the given room. - - Args: - medium: The 3PID's medium. - address: The 3PID's address. - room_id: The room we want to invite the threepid to. - - Returns: - True if the 3PID can be invited, False if not. - """ - # Bail out early without hitting the store if we don't have any callbacks to run. 
- if len(self._check_threepid_can_be_invited_callbacks) == 0: - return True - - state_events = await self._storage_controllers.state.get_current_state(room_id) - - for callback in self._check_threepid_can_be_invited_callbacks: - try: - threepid_can_be_invited = await delay_cancellation( - callback(medium, address, state_events) - ) - if threepid_can_be_invited is False: - return False - except CancelledError: - raise - except Exception as e: - logger.warning("Failed to run module API callback %s: %s", callback, e) - - return True - async def check_visibility_can_be_modified( self, room_id: str, new_visibility: str ) -> bool: @@ -533,67 +456,3 @@ class ThirdPartyEventRulesModuleApiCallbacks: logger.exception( "Failed to run module API callback %s: %s", callback, e ) - - async def on_threepid_bind(self, user_id: str, medium: str, address: str) -> None: - """Called after a threepid association has been verified and stored. - - Note that this callback is called when an association is created on the - local homeserver, not when it's created on an identity server (and then kept track - of so that it can be unbound on the same IS later on). - - THIS MODULE CALLBACK METHOD HAS BEEN DEPRECATED. Please use the - `on_add_user_third_party_identifier` callback method instead. - - Args: - user_id: the user being associated with the threepid. - medium: the threepid's medium. - address: the threepid's address. - """ - for callback in self._on_threepid_bind_callbacks: - try: - await callback(user_id, medium, address) - except Exception as e: - logger.exception( - "Failed to run module API callback %s: %s", callback, e - ) - - async def on_add_user_third_party_identifier( - self, user_id: str, medium: str, address: str - ) -> None: - """Called when an association between a user's Matrix ID and a third-party ID - (email, phone number) has successfully been registered on the homeserver. - - Args: - user_id: The User ID included in the association. 
- medium: The medium of the third-party ID (email, msisdn). - address: The address of the third-party ID (i.e. an email address). - """ - for callback in self._on_add_user_third_party_identifier_callbacks: - try: - await callback(user_id, medium, address) - except Exception as e: - logger.exception( - "Failed to run module API callback %s: %s", callback, e - ) - - async def on_remove_user_third_party_identifier( - self, user_id: str, medium: str, address: str - ) -> None: - """Called when an association between a user's Matrix ID and a third-party ID - (email, phone number) has been successfully removed on the homeserver. - - This is called *after* any known bindings on identity servers for this - association have been removed. - - Args: - user_id: The User ID included in the removed association. - medium: The medium of the third-party ID (email, msisdn). - address: The address of the third-party ID (i.e. an email address). - """ - for callback in self._on_remove_user_third_party_identifier_callbacks: - try: - await callback(user_id, medium, address) - except Exception as e: - logger.exception( - "Failed to run module API callback %s: %s", callback, e - ) diff --git a/synapse/notifier.py b/synapse/notifier.py
index 7a2b54036c..6190432b87 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py
@@ -41,6 +41,7 @@ import attr from prometheus_client import Counter from twisted.internet import defer +from twisted.internet.defer import Deferred from synapse.api.constants import EduTypes, EventTypes, HistoryVisibility, Membership from synapse.api.errors import AuthError @@ -52,6 +53,7 @@ from synapse.logging.opentracing import log_kv, start_active_span from synapse.metrics import LaterGauge from synapse.streams.config import PaginationConfig from synapse.types import ( + ISynapseReactor, JsonDict, MultiWriterStreamToken, PersistedEventPosition, @@ -61,8 +63,10 @@ from synapse.types import ( StreamToken, UserID, ) -from synapse.util.async_helpers import ObservableDeferred, timeout_deferred -from synapse.util.metrics import Measure +from synapse.util.async_helpers import ( + timeout_deferred, +) +from synapse.util.stringutils import shortstr from synapse.visibility import filter_events_for_client if TYPE_CHECKING: @@ -89,18 +93,6 @@ def count(func: Callable[[T], bool], it: Iterable[T]) -> int: return n -class _NotificationListener: - """This represents a single client connection to the events stream. - The events stream handler will have yielded to the deferred, so to - notify the handler it is sufficient to resolve the deferred. - """ - - __slots__ = ["deferred"] - - def __init__(self, deferred: "defer.Deferred"): - self.deferred = deferred - - class _NotifierUserStream: """This represents a user connected to the event stream. It tracks the most recent stream token for that user. @@ -113,59 +105,49 @@ class _NotifierUserStream: def __init__( self, + reactor: ISynapseReactor, user_id: str, rooms: StrCollection, current_token: StreamToken, time_now_ms: int, ): + self.reactor = reactor self.user_id = user_id self.rooms = set(rooms) - self.current_token = current_token # The last token for which we should wake up any streams that have a # token that comes before it. This gets updated every time we get poked. 
# We start it at the current token since if we get any streams # that have a token from before we have no idea whether they should be # woken up or not, so lets just wake them up. - self.last_notified_token = current_token + self.current_token = current_token self.last_notified_ms = time_now_ms - self.notify_deferred: ObservableDeferred[StreamToken] = ObservableDeferred( - defer.Deferred() - ) + # Set of listeners that we need to wake up when there has been a change. + self.listeners: Set[Deferred[StreamToken]] = set() - def notify( + def update_and_fetch_deferreds( self, - stream_key: StreamKeyType, - stream_id: Union[int, RoomStreamToken, MultiWriterStreamToken], + current_token: StreamToken, time_now_ms: int, - ) -> None: - """Notify any listeners for this user of a new event from an - event source. + ) -> Collection["Deferred[StreamToken]"]: + """Update the stream for this user because of a new event from an + event source, and return the set of deferreds to wake up. + Args: - stream_key: The stream the event came from. - stream_id: The new id for the stream the event came from. + current_token: The new current token. time_now_ms: The current time in milliseconds. + + Returns: + The set of deferreds that need to be called. 
""" - self.current_token = self.current_token.copy_and_advance(stream_key, stream_id) - self.last_notified_token = self.current_token + self.current_token = current_token self.last_notified_ms = time_now_ms - notify_deferred = self.notify_deferred - - log_kv( - { - "notify": self.user_id, - "stream": stream_key, - "stream_id": stream_id, - "listeners": self.count_listeners(), - } - ) - users_woken_by_stream_counter.labels(stream_key).inc() + listeners = self.listeners + self.listeners = set() - with PreserveLoggingContext(): - self.notify_deferred = ObservableDeferred(defer.Deferred()) - notify_deferred.callback(self.current_token) + return listeners def remove(self, notifier: "Notifier") -> None: """Remove this listener from all the indexes in the Notifier @@ -176,12 +158,15 @@ class _NotifierUserStream: lst = notifier.room_to_user_streams.get(room, set()) lst.discard(self) + if not lst: + notifier.room_to_user_streams.pop(room, None) + notifier.user_to_user_stream.pop(self.user_id) def count_listeners(self) -> int: - return len(self.notify_deferred.observers()) + return len(self.listeners) - def new_listener(self, token: StreamToken) -> _NotificationListener: + def new_listener(self, token: StreamToken) -> "Deferred[StreamToken]": """Returns a deferred that is resolved when there is a new token greater than the given token. @@ -191,10 +176,17 @@ class _NotifierUserStream: """ # Immediately wake up stream if something has already since happened # since their last token. - if self.last_notified_token != token: - return _NotificationListener(defer.succeed(self.current_token)) - else: - return _NotificationListener(self.notify_deferred.observe()) + if token != self.current_token: + return defer.succeed(self.current_token) + + # Create a new deferred and add it to the set of listeners. We add a + # cancel handler to remove it from the set again, to handle timeouts. 
+ deferred: "Deferred[StreamToken]" = Deferred( + canceller=lambda d: self.listeners.discard(d) + ) + self.listeners.add(deferred) + + return deferred @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -247,6 +239,7 @@ class Notifier: # List of callbacks to be notified when a lock is released self._lock_released_callback: List[Callable[[str, str, str], None]] = [] + self.reactor = hs.get_reactor() self.clock = hs.get_clock() self.appservice_handler = hs.get_application_service_handler() self._pusher_pool = hs.get_pusherpool() @@ -342,14 +335,25 @@ class Notifier: # Wake up all related user stream notifiers user_streams = self.room_to_user_streams.get(room_id, set()) time_now_ms = self.clock.time_msec() + current_token = self.event_sources.get_current_token() + + listeners: List["Deferred[StreamToken]"] = [] for user_stream in user_streams: try: - user_stream.notify( - StreamKeyType.UN_PARTIAL_STATED_ROOMS, new_token, time_now_ms + listeners.extend( + user_stream.update_and_fetch_deferreds(current_token, time_now_ms) ) except Exception: logger.exception("Failed to notify listener") + with PreserveLoggingContext(): + for listener in listeners: + listener.callback(current_token) + + users_woken_by_stream_counter.labels(StreamKeyType.UN_PARTIAL_STATED_ROOMS).inc( + len(user_streams) + ) + # Poke the replication so that other workers also see the write to # the un-partial-stated rooms stream. 
self.notify_replication() @@ -518,16 +522,22 @@ class Notifier: users = users or [] rooms = rooms or [] - with Measure(self.clock, "on_new_event"): - user_streams = set() + user_streams: Set[_NotifierUserStream] = set() - log_kv( - { - "waking_up_explicit_users": len(users), - "waking_up_explicit_rooms": len(rooms), - } - ) + log_kv( + { + "waking_up_explicit_users": len(users), + "waking_up_explicit_rooms": len(rooms), + "users": shortstr(users), + "rooms": shortstr(rooms), + "stream": stream_key, + "stream_id": new_token, + } + ) + # Only calculate which user streams to wake up if there are, in fact, + # any user streams registered. + if self.user_to_user_stream or self.room_to_user_streams: for user in users: user_stream = self.user_to_user_stream.get(str(user)) if user_stream is not None: @@ -544,25 +554,40 @@ class Notifier: ) time_now_ms = self.clock.time_msec() + current_token = self.event_sources.get_current_token() + listeners: List["Deferred[StreamToken]"] = [] for user_stream in user_streams: try: - user_stream.notify(stream_key, new_token, time_now_ms) + listeners.extend( + user_stream.update_and_fetch_deferreds( + current_token, time_now_ms + ) + ) except Exception: logger.exception("Failed to notify listener") - self.notify_replication() + # We resolve all these deferreds in one go so that we only need to + # call `PreserveLoggingContext` once, as it has a bunch of overhead + # (to calculate performance stats) + if listeners: + with PreserveLoggingContext(): + for listener in listeners: + listener.callback(current_token) - # Notify appservices. - try: - self.appservice_handler.notify_interested_services_ephemeral( - stream_key, - new_token, - users, - ) - except Exception: - logger.exception( - "Error notifying application services of ephemeral events" - ) + if user_streams: + users_woken_by_stream_counter.labels(stream_key).inc(len(user_streams)) + + self.notify_replication() + + # Notify appservices. 
+ try: + self.appservice_handler.notify_interested_services_ephemeral( + stream_key, + new_token, + users, + ) + except Exception: + logger.exception("Error notifying application services of ephemeral events") def on_new_replication_data(self) -> None: """Used to inform replication listeners that something has happened @@ -586,6 +611,7 @@ class Notifier: if room_ids is None: room_ids = await self.store.get_rooms_for_user(user_id) user_stream = _NotifierUserStream( + reactor=self.reactor, user_id=user_id, rooms=room_ids, current_token=current_token, @@ -608,8 +634,8 @@ class Notifier: # Now we wait for the _NotifierUserStream to be told there # is a new token. listener = user_stream.new_listener(prev_token) - listener.deferred = timeout_deferred( - listener.deferred, + listener = timeout_deferred( + listener, (end_time - now) / 1000.0, self.hs.get_reactor(), ) @@ -622,7 +648,7 @@ class Notifier: ) with PreserveLoggingContext(): - await listener.deferred + await listener log_kv( { diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 34ab637c3d..8249d5e84f 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -52,6 +52,7 @@ from synapse.events.snapshot import EventContext from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership +from synapse.storage.invite_rule import InviteRule from synapse.storage.roommember import ProfileInfo from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator from synapse.types import JsonValue @@ -191,9 +192,17 @@ class BulkPushRuleEvaluator: # if this event is an invite event, we may need to run rules for the user # who's been invited, otherwise they won't get told they've been invited - if event.type == EventTypes.Member and event.membership == Membership.INVITE: + if ( + event.is_state() + and event.type == EventTypes.Member + and event.membership == Membership.INVITE + ): invited = event.state_key - if invited and self.hs.is_mine_id(invited) and invited not in local_users: + invite_config = await self.store.get_invite_config_for_user(invited) + if invite_config.get_invite_rule(event.sender) != InviteRule.ALLOW: + # Invite was blocked or ignored, never notify. 
+ return {} + if self.hs.is_mine_id(invited) and invited not in local_users: local_users.append(invited) if not local_users: @@ -304,9 +313,9 @@ class BulkPushRuleEvaluator: if relation_type == "m.thread" and event.content.get( "m.relates_to", {} ).get("is_falling_back", False): - related_events["m.in_reply_to"][ - "im.vector.is_falling_back" - ] = "" + related_events["m.in_reply_to"]["im.vector.is_falling_back"] = ( + "" + ) return related_events @@ -371,8 +380,9 @@ class BulkPushRuleEvaluator: "Deferred[Tuple[int, Tuple[dict, Optional[int]], Dict[str, Dict[str, JsonValue]], Mapping[str, ProfileInfo]]]", gather_results( ( - run_in_background( # type: ignore[call-arg] - self.store.get_number_joined_users_in_room, event.room_id # type: ignore[arg-type] + run_in_background( # type: ignore[call-overload] + self.store.get_number_joined_users_in_room, + event.room_id, # type: ignore[arg-type] ), run_in_background( self._get_power_levels_and_sender_level, @@ -381,10 +391,10 @@ class BulkPushRuleEvaluator: event_id_to_event, ), run_in_background(self._related_events, event), - run_in_background( # type: ignore[call-arg] + run_in_background( # type: ignore[call-overload] self.store.get_subset_users_in_room_with_profiles, - event.room_id, # type: ignore[arg-type] - rules_by_user.keys(), # type: ignore[arg-type] + event.room_id, + rules_by_user.keys(), ), ), consumeErrors=True, @@ -435,6 +445,7 @@ class BulkPushRuleEvaluator: self._related_event_match_enabled, event.room_version.msc3931_push_features, self.hs.config.experimental.msc1767_enabled, # MSC3931 flag + self.hs.config.experimental.msc4210_enabled, ) for uid, rules in rules_by_user.items(): diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py deleted file mode 100644
index 0a14c534f7..0000000000 --- a/synapse/push/emailpusher.py +++ /dev/null
@@ -1,331 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright 2016 OpenMarket Ltd -# Copyright (C) 2023 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -# Originally licensed under the Apache License, Version 2.0: -# <http://www.apache.org/licenses/LICENSE-2.0>. -# -# [This file includes modifications made by New Vector Limited] -# -# - -import logging -from typing import TYPE_CHECKING, Dict, List, Optional - -from twisted.internet.error import AlreadyCalled, AlreadyCancelled -from twisted.internet.interfaces import IDelayedCall - -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.push import Pusher, PusherConfig, PusherConfigException, ThrottleParams -from synapse.push.mailer import Mailer -from synapse.push.push_types import EmailReason -from synapse.storage.databases.main.event_push_actions import EmailPushAction -from synapse.util.threepids import validate_email - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - -# THROTTLE is the minimum time between mail notifications sent for a given room. -# Each room maintains its own throttle counter, but each new mail notification -# sends the pending notifications for all rooms. -THROTTLE_MAX_MS = 24 * 60 * 60 * 1000 # 24h -# THROTTLE_MULTIPLIER = 6 # 10 mins, 1 hour, 6 hours, 24 hours -THROTTLE_MULTIPLIER = 144 # 10 mins, 24 hours - i.e. jump straight to 1 day - -# If no event triggers a notification for this long after the previous, -# the throttle is released. 
-# 12 hours - a gap of 12 hours in conversation is surely enough to merit a new -# notification when things get going again... -THROTTLE_RESET_AFTER_MS = 12 * 60 * 60 * 1000 - -# does each email include all unread notifs, or just the ones which have happened -# since the last mail? -# XXX: this is currently broken as it includes ones from parted rooms(!) -INCLUDE_ALL_UNREAD_NOTIFS = False - - -class EmailPusher(Pusher): - """ - A pusher that sends email notifications about events (approximately) - when they happen. - This shares quite a bit of code with httpusher: it would be good to - factor out the common parts - """ - - def __init__(self, hs: "HomeServer", pusher_config: PusherConfig, mailer: Mailer): - super().__init__(hs, pusher_config) - self.mailer = mailer - - self.store = self.hs.get_datastores().main - self.email = pusher_config.pushkey - self.timed_call: Optional[IDelayedCall] = None - self.throttle_params: Dict[str, ThrottleParams] = {} - self._inited = False - - self._is_processing = False - - # Make sure that the email is valid. - try: - validate_email(self.email) - except ValueError: - raise PusherConfigException("Invalid email") - - self._delay_before_mail_ms = self.hs.config.email.notif_delay_before_mail_ms - - def on_started(self, should_check_for_notifs: bool) -> None: - """Called when this pusher has been started. - - Args: - should_check_for_notifs: Whether we should immediately - check for push to send. 
Set to False only if it's known there - is nothing to send - """ - if should_check_for_notifs and self.mailer is not None: - self._start_processing() - - def on_stop(self) -> None: - if self.timed_call: - try: - self.timed_call.cancel() - except (AlreadyCalled, AlreadyCancelled): - pass - self.timed_call = None - - def on_new_receipts(self) -> None: - # We could wake up and cancel the timer but there tend to be quite a - # lot of read receipts so it's probably less work to just let the - # timer fire - pass - - def on_timer(self) -> None: - self.timed_call = None - self._start_processing() - - def _start_processing(self) -> None: - if self._is_processing: - return - - run_as_background_process("emailpush.process", self._process) - - def _pause_processing(self) -> None: - """Used by tests to temporarily pause processing of events. - - Asserts that its not currently processing. - """ - assert not self._is_processing - self._is_processing = True - - def _resume_processing(self) -> None: - """Used by tests to resume processing of events after pausing.""" - assert self._is_processing - self._is_processing = False - self._start_processing() - - async def _process(self) -> None: - # we should never get here if we are already processing - assert not self._is_processing - - try: - self._is_processing = True - - if not self._inited: - # this is our first loop: load up the throttle params - assert self.pusher_id is not None - self.throttle_params = await self.store.get_throttle_params_by_room( - self.pusher_id - ) - self._inited = True - - # if the max ordering changes while we're running _unsafe_process, - # call it again, and so on until we've caught up. 
- while True: - starting_max_ordering = self.max_stream_ordering - try: - await self._unsafe_process() - except Exception: - logger.exception("Exception processing notifs") - if self.max_stream_ordering == starting_max_ordering: - break - finally: - self._is_processing = False - - async def _unsafe_process(self) -> None: - """ - Main logic of the push loop without the wrapper function that sets - up logging, measures and guards against multiple instances of it - being run. - """ - start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering - unprocessed = ( - await self.store.get_unread_push_actions_for_user_in_range_for_email( - self.user_id, start, self.max_stream_ordering - ) - ) - - soonest_due_at: Optional[int] = None - - if not unprocessed: - await self.save_last_stream_ordering_and_success(self.max_stream_ordering) - return - - for push_action in unprocessed: - received_at = push_action.received_ts - if received_at is None: - received_at = 0 - notif_ready_at = received_at + self._delay_before_mail_ms - - room_ready_at = self.room_ready_to_notify_at(push_action.room_id) - - should_notify_at = max(notif_ready_at, room_ready_at) - - if should_notify_at <= self.clock.time_msec(): - # one of our notifications is ready for sending, so we send - # *one* email updating the user on their notifications, - # we then consider all previously outstanding notifications - # to be delivered. 
- - reason: EmailReason = { - "room_id": push_action.room_id, - "now": self.clock.time_msec(), - "received_at": received_at, - "delay_before_mail_ms": self._delay_before_mail_ms, - "last_sent_ts": self.get_room_last_sent_ts(push_action.room_id), - "throttle_ms": self.get_room_throttle_ms(push_action.room_id), - } - - await self.send_notification(unprocessed, reason) - - await self.save_last_stream_ordering_and_success( - max(ea.stream_ordering for ea in unprocessed) - ) - - # we update the throttle on all the possible unprocessed push actions - for ea in unprocessed: - await self.sent_notif_update_throttle(ea.room_id, ea) - break - else: - if soonest_due_at is None or should_notify_at < soonest_due_at: - soonest_due_at = should_notify_at - - if self.timed_call is not None: - try: - self.timed_call.cancel() - except (AlreadyCalled, AlreadyCancelled): - pass - self.timed_call = None - - if soonest_due_at is not None: - self.timed_call = self.hs.get_reactor().callLater( - self.seconds_until(soonest_due_at), self.on_timer - ) - - async def save_last_stream_ordering_and_success( - self, last_stream_ordering: int - ) -> None: - self.last_stream_ordering = last_stream_ordering - pusher_still_exists = ( - await self.store.update_pusher_last_stream_ordering_and_success( - self.app_id, - self.email, - self.user_id, - last_stream_ordering, - self.clock.time_msec(), - ) - ) - if not pusher_still_exists: - # The pusher has been deleted while we were processing, so - # lets just stop and return. 
- self.on_stop() - - def seconds_until(self, ts_msec: int) -> float: - secs = (ts_msec - self.clock.time_msec()) / 1000 - return max(secs, 0) - - def get_room_throttle_ms(self, room_id: str) -> int: - if room_id in self.throttle_params: - return self.throttle_params[room_id].throttle_ms - else: - return 0 - - def get_room_last_sent_ts(self, room_id: str) -> int: - if room_id in self.throttle_params: - return self.throttle_params[room_id].last_sent_ts - else: - return 0 - - def room_ready_to_notify_at(self, room_id: str) -> int: - """ - Determines whether throttling should prevent us from sending an email - for the given room - - Returns: - The timestamp when we are next allowed to send an email notif - for this room - """ - last_sent_ts = self.get_room_last_sent_ts(room_id) - throttle_ms = self.get_room_throttle_ms(room_id) - - may_send_at = last_sent_ts + throttle_ms - return may_send_at - - async def sent_notif_update_throttle( - self, room_id: str, notified_push_action: EmailPushAction - ) -> None: - # We have sent a notification, so update the throttle accordingly. - # If the event that triggered the notif happened more than - # THROTTLE_RESET_AFTER_MS after the previous one that triggered a - # notif, we release the throttle. Otherwise, the throttle is increased. 
- time_of_previous_notifs = await self.store.get_time_of_last_push_action_before( - notified_push_action.stream_ordering - ) - - time_of_this_notifs = notified_push_action.received_ts - - if time_of_previous_notifs is not None and time_of_this_notifs is not None: - gap = time_of_this_notifs - time_of_previous_notifs - else: - # if we don't know the arrival time of one of the notifs (it was not - # stored prior to email notification code) then assume a gap of - # zero which will just not reset the throttle - gap = 0 - - current_throttle_ms = self.get_room_throttle_ms(room_id) - - if gap > THROTTLE_RESET_AFTER_MS: - new_throttle_ms = self._delay_before_mail_ms - else: - if current_throttle_ms == 0: - new_throttle_ms = self._delay_before_mail_ms - else: - new_throttle_ms = min( - current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS - ) - self.throttle_params[room_id] = ThrottleParams( - self.clock.time_msec(), - new_throttle_ms, - ) - assert self.pusher_id is not None - await self.store.set_throttle_params( - self.pusher_id, room_id, self.throttle_params[room_id] - ) - - async def send_notification( - self, push_actions: List[EmailPushAction], reason: EmailReason - ) -> None: - logger.info("Sending notif email for user %r", self.user_id) - - await self.mailer.send_notification_mail( - self.app_id, self.user_id, self.email, push_actions, reason - ) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index dd9b64d6ef..7df8a128c9 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py
@@ -127,6 +127,11 @@ class HttpPusher(Pusher): if self.data is None: raise PusherConfigException("'data' key can not be null for HTTP pusher") + # Check if badge counts should be disabled for this push gateway + self.disable_badge_count = self.hs.config.experimental.msc4076_enabled and bool( + self.data.get("org.matrix.msc4076.disable_badge_count", False) + ) + self.name = "%s/%s/%s" % ( pusher_config.user_name, pusher_config.app_id, @@ -200,6 +205,12 @@ class HttpPusher(Pusher): if self._is_processing: return + # Check if we are trying, but failing, to contact the pusher. If so, we + # don't try and start processing immediately and instead wait for the + # retry loop to try again later (which is controlled by the timer). + if self.failing_since and self.timed_call and self.timed_call.active(): + return + run_as_background_process("httppush.process", self._process) async def _process(self) -> None: @@ -461,9 +472,10 @@ class HttpPusher(Pusher): content: JsonDict = { "event_id": event.event_id, "room_id": event.room_id, - "counts": {"unread": badge}, "prio": priority, } + if not self.disable_badge_count: + content["counts"] = {"unread": badge} # event_id_only doesn't include the tweaks, so override them. tweaks = {} else: @@ -478,11 +490,11 @@ class HttpPusher(Pusher): "type": event.type, "sender": event.user_id, "prio": priority, - "counts": { - "unread": badge, - # 'missed_calls': 2 - }, } + if not self.disable_badge_count: + content["counts"] = { + "unread": badge, + } if event.type == "m.room.member" and event.is_state(): content["membership"] = event.content["membership"] content["user_is_target"] = event.state_key == self.user_id diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py deleted file mode 100644
index cf611bd90b..0000000000 --- a/synapse/push/mailer.py +++ /dev/null
@@ -1,1003 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright 2016 OpenMarket Ltd -# Copyright (C) 2023 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -# Originally licensed under the Apache License, Version 2.0: -# <http://www.apache.org/licenses/LICENSE-2.0>. -# -# [This file includes modifications made by New Vector Limited] -# -# - -import logging -import urllib.parse -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, TypeVar - -import bleach -import jinja2 -from markupsafe import Markup -from prometheus_client import Counter - -from synapse.api.constants import EventContentFields, EventTypes, Membership, RoomTypes -from synapse.api.errors import StoreError -from synapse.config.emailconfig import EmailSubjectConfig -from synapse.events import EventBase -from synapse.push.presentable_names import ( - calculate_room_name, - descriptor_from_member_events, - name_from_member_event, -) -from synapse.push.push_types import ( - EmailReason, - MessageVars, - NotifVars, - RoomVars, - TemplateVars, -) -from synapse.storage.databases.main.event_push_actions import EmailPushAction -from synapse.types import StateMap, UserID -from synapse.types.state import StateFilter -from synapse.util.async_helpers import concurrently_execute -from synapse.visibility import filter_events_for_client - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - -T = TypeVar("T") - -emails_sent_counter = Counter( - "synapse_emails_sent_total", - "Emails sent by type", - ["type"], -) - - -CONTEXT_BEFORE = 1 -CONTEXT_AFTER = 1 - -# From 
https://github.com/matrix-org/matrix-react-sdk/blob/master/src/HtmlUtils.js -ALLOWED_TAGS = [ - "font", # custom to matrix for IRC-style font coloring - "del", # for markdown - # deliberately no h1/h2 to stop people shouting. - "h3", - "h4", - "h5", - "h6", - "blockquote", - "p", - "a", - "ul", - "ol", - "nl", - "li", - "b", - "i", - "u", - "strong", - "em", - "strike", - "code", - "hr", - "br", - "div", - "table", - "thead", - "caption", - "tbody", - "tr", - "th", - "td", - "pre", -] -ALLOWED_ATTRS = { - # custom ones first: - "font": ["color"], # custom to matrix - "a": ["href", "name", "target"], # remote target: custom to matrix - # We don't currently allow img itself by default, but this - # would make sense if we did - "img": ["src"], -} -# When bleach release a version with this option, we can specify schemes -# ALLOWED_SCHEMES = ["http", "https", "ftp", "mailto"] - - -class Mailer: - def __init__( - self, - hs: "HomeServer", - app_name: str, - template_html: jinja2.Template, - template_text: jinja2.Template, - ): - self.hs = hs - self.template_html = template_html - self.template_text = template_text - - self.send_email_handler = hs.get_send_email_handler() - self.store = self.hs.get_datastores().main - self._state_storage_controller = self.hs.get_storage_controllers().state - self.macaroon_gen = self.hs.get_macaroon_generator() - self.state_handler = self.hs.get_state_handler() - self._storage_controllers = hs.get_storage_controllers() - self.app_name = app_name - self.email_subjects: EmailSubjectConfig = hs.config.email.email_subjects - - logger.info("Created Mailer for app_name %s" % app_name) - - emails_sent_counter.labels("password_reset") - - async def send_password_reset_mail( - self, email_address: str, token: str, client_secret: str, sid: str - ) -> None: - """Send an email with a password reset link to a user - - Args: - email_address: Email address we're sending the password - reset to - token: Unique token generated by the server to verify - the 
email was received - client_secret: Unique token generated by the client to - group together multiple email sending attempts - sid: The generated session ID - """ - params = {"token": token, "client_secret": client_secret, "sid": sid} - link = ( - self.hs.config.server.public_baseurl - + "_synapse/client/password_reset/email/submit_token?%s" - % urllib.parse.urlencode(params) - ) - - template_vars: TemplateVars = {"link": link} - - emails_sent_counter.labels("password_reset").inc() - - await self.send_email( - email_address, - self.email_subjects.password_reset - % {"server_name": self.hs.config.server.server_name, "app": self.app_name}, - template_vars, - ) - - emails_sent_counter.labels("registration") - - async def send_registration_mail( - self, email_address: str, token: str, client_secret: str, sid: str - ) -> None: - """Send an email with a registration confirmation link to a user - - Args: - email_address: Email address we're sending the registration - link to - token: Unique token generated by the server to verify - the email was received - client_secret: Unique token generated by the client to - group together multiple email sending attempts - sid: The generated session ID - """ - params = {"token": token, "client_secret": client_secret, "sid": sid} - link = ( - self.hs.config.server.public_baseurl - + "_matrix/client/unstable/registration/email/submit_token?%s" - % urllib.parse.urlencode(params) - ) - - template_vars: TemplateVars = {"link": link} - - emails_sent_counter.labels("registration").inc() - - await self.send_email( - email_address, - self.email_subjects.email_validation - % {"server_name": self.hs.config.server.server_name, "app": self.app_name}, - template_vars, - ) - - emails_sent_counter.labels("already_in_use") - - async def send_already_in_use_mail(self, email_address: str) -> None: - """Send an email if the address is already bound to an user account - - Args: - email_address: Email address we're sending to the "already in use" mail - 
""" - - await self.send_email( - email_address, - self.email_subjects.email_already_in_use - % {"server_name": self.hs.config.server.server_name, "app": self.app_name}, - {}, - ) - - emails_sent_counter.labels("add_threepid") - - async def send_add_threepid_mail( - self, email_address: str, token: str, client_secret: str, sid: str - ) -> None: - """Send an email with a validation link to a user for adding a 3pid to their account - - Args: - email_address: Email address we're sending the validation link to - - token: Unique token generated by the server to verify the email was received - - client_secret: Unique token generated by the client to group together - multiple email sending attempts - - sid: The generated session ID - """ - params = {"token": token, "client_secret": client_secret, "sid": sid} - link = ( - self.hs.config.server.public_baseurl - + "_matrix/client/unstable/add_threepid/email/submit_token?%s" - % urllib.parse.urlencode(params) - ) - - template_vars: TemplateVars = {"link": link} - - emails_sent_counter.labels("add_threepid").inc() - - await self.send_email( - email_address, - self.email_subjects.email_validation - % {"server_name": self.hs.config.server.server_name, "app": self.app_name}, - template_vars, - ) - - emails_sent_counter.labels("notification") - - async def send_notification_mail( - self, - app_id: str, - user_id: str, - email_address: str, - push_actions: Iterable[EmailPushAction], - reason: EmailReason, - ) -> None: - """ - Send email regarding a user's room notifications - - Params: - app_id: The application receiving the notification. - user_id: The user receiving the notification. - email_address: The email address receiving the notification. - push_actions: All outstanding notifications. - reason: The notification that was ready and is the cause of an email - being sent. 
- """ - rooms_in_order = deduped_ordered_list([pa.room_id for pa in push_actions]) - - notif_events = await self.store.get_events([pa.event_id for pa in push_actions]) - - notifs_by_room: Dict[str, List[EmailPushAction]] = {} - for pa in push_actions: - notifs_by_room.setdefault(pa.room_id, []).append(pa) - - # collect the current state for all the rooms in which we have - # notifications - state_by_room = {} - - try: - user_display_name = await self.store.get_profile_displayname( - UserID.from_string(user_id) - ) - if user_display_name is None: - user_display_name = user_id - except StoreError: - user_display_name = user_id - - async def _fetch_room_state(room_id: str) -> None: - room_state = await self._state_storage_controller.get_current_state_ids( - room_id - ) - state_by_room[room_id] = room_state - - # Run at most 3 of these at once: sync does 10 at a time but email - # notifs are much less realtime than sync so we can afford to wait a bit. - await concurrently_execute(_fetch_room_state, rooms_in_order, 3) - - # actually sort our so-called rooms_in_order list, most recent room first - rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1].received_ts or 0)) - - rooms: List[RoomVars] = [] - - for r in rooms_in_order: - roomvars = await self._get_room_vars( - r, user_id, notifs_by_room[r], notif_events, state_by_room[r] - ) - rooms.append(roomvars) - - reason["room_name"] = await calculate_room_name( - self.store, - state_by_room[reason["room_id"]], - user_id, - fallback_to_members=True, - ) - - if len(notifs_by_room) == 1: - # Only one room has new stuff - room_id = list(notifs_by_room.keys())[0] - - summary_text = await self._make_summary_text_single_room( - room_id, - notifs_by_room[room_id], - state_by_room[room_id], - notif_events, - user_id, - ) - else: - summary_text = await self._make_summary_text( - notifs_by_room, state_by_room, notif_events, reason - ) - - unsubscribe_link = self._make_unsubscribe_link(user_id, app_id, email_address) - - 
template_vars: TemplateVars = { - "user_display_name": user_display_name, - "unsubscribe_link": unsubscribe_link, - "summary_text": summary_text, - "rooms": rooms, - "reason": reason, - } - - emails_sent_counter.labels("notification").inc() - - await self.send_email( - email_address, summary_text, template_vars, unsubscribe_link - ) - - async def send_email( - self, - email_address: str, - subject: str, - extra_template_vars: TemplateVars, - unsubscribe_link: Optional[str] = None, - ) -> None: - """Send an email with the given information and template text""" - template_vars: TemplateVars = { - "app_name": self.app_name, - "server_name": self.hs.config.server.server_name, - } - - template_vars.update(extra_template_vars) - - html_text = self.template_html.render(**template_vars) - plain_text = self.template_text.render(**template_vars) - - await self.send_email_handler.send_email( - email_address=email_address, - subject=subject, - app_name=self.app_name, - html=html_text, - text=plain_text, - # Include the List-Unsubscribe header which some clients render in the UI. - # Per RFC 2369, this can be a URL or mailto URL. See - # https://www.rfc-editor.org/rfc/rfc2369.html#section-3.2 - # - # It is preferred to use email, but Synapse doesn't support incoming email. - # - # Also include the List-Unsubscribe-Post header from RFC 8058. See - # https://www.rfc-editor.org/rfc/rfc8058.html#section-3.1 - # - # Note that many email clients will not render the unsubscribe link - # unless DKIM, etc. is properly setup. - additional_headers=( - { - "List-Unsubscribe-Post": "List-Unsubscribe=One-Click", - "List-Unsubscribe": f"<{unsubscribe_link}>", - } - if unsubscribe_link - else None - ), - ) - - async def _get_room_vars( - self, - room_id: str, - user_id: str, - notifs: Iterable[EmailPushAction], - notif_events: Dict[str, EventBase], - room_state_ids: StateMap[str], - ) -> RoomVars: - """ - Generate the variables for notifications on a per-room basis. 
- - Args: - room_id: The room ID - user_id: The user receiving the notification. - notifs: The outstanding push actions for this room. - notif_events: The events related to the above notifications. - room_state_ids: The event IDs of the current room state. - - Returns: - A dictionary to be added to the template context. - """ - - # Check if one of the notifs is an invite event for the user. - is_invite = False - for n in notifs: - ev = notif_events[n.event_id] - if ev.type == EventTypes.Member and ev.state_key == user_id: - if ev.content.get("membership") == Membership.INVITE: - is_invite = True - break - - room_name = await calculate_room_name(self.store, room_state_ids, user_id) - - room_vars: RoomVars = { - "title": room_name, - "hash": string_ordinal_total(room_id), # See sender avatar hash - "notifs": [], - "invite": is_invite, - "link": self._make_room_link(room_id), - "avatar_url": await self._get_room_avatar(room_state_ids), - } - - if not is_invite: - for n in notifs: - notifvars = await self._get_notif_vars( - n, user_id, notif_events[n.event_id], room_state_ids - ) - - # merge overlapping notifs together. - # relies on the notifs being in chronological order. - merge = False - if room_vars["notifs"] and "messages" in room_vars["notifs"][-1]: - prev_messages = room_vars["notifs"][-1]["messages"] - for message in notifvars["messages"]: - pm = list( - filter(lambda pm: pm["id"] == message["id"], prev_messages) - ) - if pm: - if not message["is_historical"]: - pm[0]["is_historical"] = False - merge = True - elif merge: - # we're merging, so append any remaining messages - # in this notif to the previous one - prev_messages.append(message) - - if not merge: - room_vars["notifs"].append(notifvars) - - return room_vars - - async def _get_room_avatar( - self, - room_state_ids: StateMap[str], - ) -> Optional[str]: - """ - Retrieve the avatar url for this room---if it exists. - - Args: - room_state_ids: The event IDs of the current room state. 
- - Returns: - room's avatar url if it's present and a string; otherwise None. - """ - event_id = room_state_ids.get((EventTypes.RoomAvatar, "")) - if event_id: - ev = await self.store.get_event(event_id) - url = ev.content.get("url") - if isinstance(url, str): - return url - return None - - async def _get_notif_vars( - self, - notif: EmailPushAction, - user_id: str, - notif_event: EventBase, - room_state_ids: StateMap[str], - ) -> NotifVars: - """ - Generate the variables for a single notification. - - Args: - notif: The outstanding notification for this room. - user_id: The user receiving the notification. - notif_event: The event related to the above notification. - room_state_ids: The event IDs of the current room state. - - Returns: - A dictionary to be added to the template context. - """ - - results = await self.store.get_events_around( - notif.room_id, - notif.event_id, - before_limit=CONTEXT_BEFORE, - after_limit=CONTEXT_AFTER, - ) - - ret: NotifVars = { - "link": self._make_notif_link(notif), - "ts": notif.received_ts, - "messages": [], - } - - the_events = await filter_events_for_client( - self._storage_controllers, - user_id, - results.events_before, - ) - the_events.append(notif_event) - - for event in the_events: - messagevars = await self._get_message_vars(notif, event, room_state_ids) - if messagevars is not None: - ret["messages"].append(messagevars) - - return ret - - async def _get_message_vars( - self, notif: EmailPushAction, event: EventBase, room_state_ids: StateMap[str] - ) -> Optional[MessageVars]: - """ - Generate the variables for a single event, if possible. - - Args: - notif: The outstanding notification for this room. - event: The event under consideration. - room_state_ids: The event IDs of the current room state. - - Returns: - A dictionary to be added to the template context, or None if the - event cannot be processed. 
- """ - if event.type != EventTypes.Message and event.type != EventTypes.Encrypted: - return None - - # Get the sender's name and avatar from the room state. - type_state_key = ("m.room.member", event.sender) - sender_state_event_id = room_state_ids.get(type_state_key) - if sender_state_event_id: - sender_state_event: Optional[EventBase] = await self.store.get_event( - sender_state_event_id - ) - else: - # Attempt to check the historical state for the room. - historical_state = await self._state_storage_controller.get_state_for_event( - event.event_id, StateFilter.from_types((type_state_key,)) - ) - sender_state_event = historical_state.get(type_state_key) - - if sender_state_event: - sender_name = name_from_member_event(sender_state_event) - sender_avatar_url: Optional[str] = sender_state_event.content.get( - "avatar_url" - ) - else: - # No state could be found, fallback to the MXID. - sender_name = event.sender - sender_avatar_url = None - - # 'hash' for deterministically picking default images: use - # sender_hash % the number of default images to choose from - sender_hash = string_ordinal_total(event.sender) - - ret: MessageVars = { - "event_type": event.type, - "is_historical": event.event_id != notif.event_id, - "id": event.event_id, - "ts": event.origin_server_ts, - "sender_name": sender_name, - "sender_avatar_url": sender_avatar_url, - "sender_hash": sender_hash, - } - - # Encrypted messages don't have any additional useful information. 
- if event.type == EventTypes.Encrypted: - return ret - - msgtype = event.content.get("msgtype") - if not isinstance(msgtype, str): - msgtype = None - - ret["msgtype"] = msgtype - - if msgtype == "m.text": - self._add_text_message_vars(ret, event) - elif msgtype == "m.image": - self._add_image_message_vars(ret, event) - - if "body" in event.content: - ret["body_text_plain"] = event.content["body"] - - return ret - - def _add_text_message_vars( - self, messagevars: MessageVars, event: EventBase - ) -> None: - """ - Potentially add a sanitised message body to the message variables. - - Args: - messagevars: The template context to be modified. - event: The event under consideration. - """ - msgformat = event.content.get("format") - if not isinstance(msgformat, str): - msgformat = None - - formatted_body = event.content.get("formatted_body") - body = event.content.get("body") - - if msgformat == "org.matrix.custom.html" and formatted_body: - messagevars["body_text_html"] = safe_markup(formatted_body) - elif body: - messagevars["body_text_html"] = safe_text(body) - - def _add_image_message_vars( - self, messagevars: MessageVars, event: EventBase - ) -> None: - """ - Potentially add an image URL to the message variables. - - Args: - messagevars: The template context to be modified. - event: The event under consideration. - """ - if "url" in event.content: - messagevars["image_url"] = event.content["url"] - - async def _make_summary_text_single_room( - self, - room_id: str, - notifs: List[EmailPushAction], - room_state_ids: StateMap[str], - notif_events: Dict[str, EventBase], - user_id: str, - ) -> str: - """ - Make a summary text for the email when only a single room has notifications. - - Args: - room_id: The ID of the room. - notifs: The push actions for this room. - room_state_ids: The state map for the room. - notif_events: A map of event ID -> notification event. - user_id: The user receiving the notification. - - Returns: - The summary text. 
- """ - # If the room has some kind of name, use it, but we don't - # want the generated-from-names one here otherwise we'll - # end up with, "new message from Bob in the Bob room" - room_name = await calculate_room_name( - self.store, room_state_ids, user_id, fallback_to_members=False - ) - - # See if one of the notifs is an invite event for the user - invite_event = None - for n in notifs: - ev = notif_events[n.event_id] - if ev.type == EventTypes.Member and ev.state_key == user_id: - if ev.content.get("membership") == Membership.INVITE: - invite_event = ev - break - - if invite_event: - inviter_member_event_id = room_state_ids.get( - ("m.room.member", invite_event.sender) - ) - inviter_name = invite_event.sender - if inviter_member_event_id: - inviter_member_event = await self.store.get_event( - inviter_member_event_id, allow_none=True - ) - if inviter_member_event: - inviter_name = name_from_member_event(inviter_member_event) - - if room_name is None: - return self.email_subjects.invite_from_person % { - "person": inviter_name, - "app": self.app_name, - } - - # If the room is a space, it gets a slightly different topic. 
- create_event_id = room_state_ids.get(("m.room.create", "")) - if create_event_id: - create_event = await self.store.get_event( - create_event_id, allow_none=True - ) - if ( - create_event - and create_event.content.get(EventContentFields.ROOM_TYPE) - == RoomTypes.SPACE - ): - return self.email_subjects.invite_from_person_to_space % { - "person": inviter_name, - "space": room_name, - "app": self.app_name, - } - - return self.email_subjects.invite_from_person_to_room % { - "person": inviter_name, - "room": room_name, - "app": self.app_name, - } - - if len(notifs) == 1: - # There is just the one notification, so give some detail - sender_name = None - event = notif_events[notifs[0].event_id] - if ("m.room.member", event.sender) in room_state_ids: - state_event_id = room_state_ids[("m.room.member", event.sender)] - state_event = await self.store.get_event(state_event_id) - sender_name = name_from_member_event(state_event) - - if sender_name is not None and room_name is not None: - return self.email_subjects.message_from_person_in_room % { - "person": sender_name, - "room": room_name, - "app": self.app_name, - } - elif sender_name is not None: - return self.email_subjects.message_from_person % { - "person": sender_name, - "app": self.app_name, - } - - # The sender is unknown, just use the room name (or ID). 
- return self.email_subjects.messages_in_room % { - "room": room_name or room_id, - "app": self.app_name, - } - else: - # There's more than one notification for this room, so just - # say there are several - if room_name is not None: - return self.email_subjects.messages_in_room % { - "room": room_name, - "app": self.app_name, - } - - return await self._make_summary_text_from_member_events( - room_id, notifs, room_state_ids, notif_events - ) - - async def _make_summary_text( - self, - notifs_by_room: Dict[str, List[EmailPushAction]], - room_state_ids: Dict[str, StateMap[str]], - notif_events: Dict[str, EventBase], - reason: EmailReason, - ) -> str: - """ - Make a summary text for the email when multiple rooms have notifications. - - Args: - notifs_by_room: A map of room ID to the push actions for that room. - room_state_ids: A map of room ID to the state map for that room. - notif_events: A map of event ID -> notification event. - reason: The reason this notification is being sent. - - Returns: - The summary text. - """ - # Stuff's happened in multiple different rooms - # ...but we still refer to the 'reason' room which triggered the mail - if reason["room_name"] is not None: - return self.email_subjects.messages_in_room_and_others % { - "room": reason["room_name"], - "app": self.app_name, - } - - room_id = reason["room_id"] - return await self._make_summary_text_from_member_events( - room_id, notifs_by_room[room_id], room_state_ids[room_id], notif_events - ) - - async def _make_summary_text_from_member_events( - self, - room_id: str, - notifs: List[EmailPushAction], - room_state_ids: StateMap[str], - notif_events: Dict[str, EventBase], - ) -> str: - """ - Make a summary text for the email when only a single room has notifications. - - Args: - room_id: The ID of the room. - notifs: The push actions for this room. - room_state_ids: The state map for the room. - notif_events: A map of event ID -> notification event. - - Returns: - The summary text. 
- """ - # If the room doesn't have a name, say who the messages - # are from explicitly to avoid, "messages in the Bob room" - - # Find the latest event ID for each sender, note that the notifications - # are already in descending received_ts. - sender_ids = {} - for n in notifs: - sender = notif_events[n.event_id].sender - if sender not in sender_ids: - sender_ids[sender] = n.event_id - - # Get the actual member events (in order to calculate a pretty name for - # the room). - member_event_ids = [] - member_events = {} - for sender_id, event_id in sender_ids.items(): - type_state_key = ("m.room.member", sender_id) - sender_state_event_id = room_state_ids.get(type_state_key) - if sender_state_event_id: - member_event_ids.append(sender_state_event_id) - else: - # Attempt to check the historical state for the room. - historical_state = ( - await self._state_storage_controller.get_state_for_event( - event_id, StateFilter.from_types((type_state_key,)) - ) - ) - sender_state_event = historical_state.get(type_state_key) - if sender_state_event: - member_events[event_id] = sender_state_event - member_events.update(await self.store.get_events(member_event_ids)) - - if not member_events: - # No member events were found! Maybe the room is empty? - # Fallback to the room ID (note that if there was a room name this - # would already have been used previously). - return self.email_subjects.messages_in_room % { - "room": room_id, - "app": self.app_name, - } - - # There was a single sender. - if len(member_events) == 1: - return self.email_subjects.messages_from_person % { - "person": descriptor_from_member_events(member_events.values()), - "app": self.app_name, - } - - # There was more than one sender, use the first one and a tweaked template. 
- return self.email_subjects.messages_from_person_and_others % { - "person": descriptor_from_member_events(list(member_events.values())[:1]), - "app": self.app_name, - } - - def _make_room_link(self, room_id: str) -> str: - """ - Generate a link to open a room in the web client. - - Args: - room_id: The room ID to generate a link to. - - Returns: - A link to open a room in the web client. - """ - if self.hs.config.email.email_riot_base_url: - base_url = "%s/#/room" % (self.hs.config.email.email_riot_base_url) - elif self.app_name == "Vector": - # need /beta for Universal Links to work on iOS - base_url = "https://vector.im/beta/#/room" - else: - base_url = "https://matrix.to/#" - return "%s/%s" % (base_url, room_id) - - def _make_notif_link(self, notif: EmailPushAction) -> str: - """ - Generate a link to open an event in the web client. - - Args: - notif: The notification to generate a link for. - - Returns: - A link to open the notification in the web client. - """ - if self.hs.config.email.email_riot_base_url: - return "%s/#/room/%s/%s" % ( - self.hs.config.email.email_riot_base_url, - notif.room_id, - notif.event_id, - ) - elif self.app_name == "Vector": - # need /beta for Universal Links to work on iOS - return "https://vector.im/beta/#/room/%s/%s" % ( - notif.room_id, - notif.event_id, - ) - else: - return "https://matrix.to/#/%s/%s" % (notif.room_id, notif.event_id) - - def _make_unsubscribe_link( - self, user_id: str, app_id: str, email_address: str - ) -> str: - """ - Generate a link to unsubscribe from email notifications. - - Args: - user_id: The user receiving the notification. - app_id: The application receiving the notification. - email_address: The email address receiving the notification. - - Returns: - A link to unsubscribe from email notifications. 
- """ - params = { - "access_token": self.macaroon_gen.generate_delete_pusher_token( - user_id, app_id, email_address - ), - "app_id": app_id, - "pushkey": email_address, - } - - return "%s_synapse/client/unsubscribe?%s" % ( - self.hs.config.server.public_baseurl, - urllib.parse.urlencode(params), - ) - - -def safe_markup(raw_html: str) -> Markup: - """ - Sanitise a raw HTML string to a set of allowed tags and attributes, and linkify any bare URLs. - - Args - raw_html: Unsafe HTML. - - Returns: - A Markup object ready to safely use in a Jinja template. - """ - return Markup( - bleach.linkify( - bleach.clean( - raw_html, - tags=ALLOWED_TAGS, - attributes=ALLOWED_ATTRS, - # bleach master has this, but it isn't released yet - # protocols=ALLOWED_SCHEMES, - strip=True, - ) - ) - ) - - -def safe_text(raw_text: str) -> Markup: - """ - Sanitise text (escape any HTML tags), and then linkify any bare URLs. - - Args - raw_text: Unsafe text which might include HTML markup. - - Returns: - A Markup object ready to safely use in a Jinja template. - """ - return Markup( - bleach.linkify(bleach.clean(raw_text, tags=[], attributes=[], strip=False)) - ) - - -def deduped_ordered_list(it: Iterable[T]) -> List[T]: - seen = set() - ret = [] - for item in it: - if item not in seen: - seen.add(item) - ret.append(item) - return ret - - -def string_ordinal_total(s: str) -> int: - tot = 0 - for c in s: - tot += ord(c) - return tot diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 1ef881f702..3f3e4a9234 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py
@@ -74,9 +74,13 @@ async def get_context_for_event( room_state = [] if ev.content.get("membership") == Membership.INVITE: - room_state = ev.unsigned.get("invite_room_state", []) + invite_room_state = ev.unsigned.get("invite_room_state", []) + if isinstance(invite_room_state, list): + room_state = invite_room_state elif ev.content.get("membership") == Membership.KNOCK: - room_state = ev.unsigned.get("knock_room_state", []) + knock_room_state = ev.unsigned.get("knock_room_state", []) + if isinstance(knock_room_state, list): + room_state = knock_room_state # Ideally we'd reuse the logic in `calculate_room_name`, but that gets # complicated to handle partial events vs pulling events from the DB. diff --git a/synapse/push/push_types.py b/synapse/push/push_types.py
index 201ec97219..57fa926a46 100644 --- a/synapse/push/push_types.py +++ b/synapse/push/push_types.py
@@ -18,9 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import List, Optional - -from typing_extensions import TypedDict +from typing import List, Optional, TypedDict class EmailReason(TypedDict, total=False): diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index 9a5dd7a9d4..39bfe0dd33 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py
@@ -23,9 +23,7 @@ import logging from typing import TYPE_CHECKING, Callable, Dict, Optional from synapse.push import Pusher, PusherConfig -from synapse.push.emailpusher import EmailPusher from synapse.push.httppusher import HttpPusher -from synapse.push.mailer import Mailer if TYPE_CHECKING: from synapse.server import HomeServer @@ -42,17 +40,6 @@ class PusherFactory: "http": HttpPusher } - logger.info("email enable notifs: %r", hs.config.email.email_enable_notifs) - if hs.config.email.email_enable_notifs: - self.mailers: Dict[str, Mailer] = {} - - self._notif_template_html = hs.config.email.email_notif_template_html - self._notif_template_text = hs.config.email.email_notif_template_text - - self.pusher_types["email"] = self._create_email_pusher - - logger.info("defined email pusher type") - def create_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]: kind = pusher_config.kind f = self.pusher_types.get(kind, None) @@ -60,28 +47,3 @@ class PusherFactory: return None logger.debug("creating %s pusher for %r", kind, pusher_config) return f(self.hs, pusher_config) - - def _create_email_pusher( - self, _hs: "HomeServer", pusher_config: PusherConfig - ) -> EmailPusher: - app_name = self._app_name_from_pusherdict(pusher_config) - mailer = self.mailers.get(app_name) - if not mailer: - mailer = Mailer( - hs=self.hs, - app_name=app_name, - template_html=self._notif_template_html, - template_text=self._notif_template_text, - ) - self.mailers[app_name] = mailer - return EmailPusher(self.hs, pusher_config, mailer) - - def _app_name_from_pusherdict(self, pusher_config: PusherConfig) -> str: - data = pusher_config.data - - if isinstance(data, dict): - brand = data.get("brand") - if isinstance(brand, str): - return brand - - return self.config.email.email_app_name diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 0a7541b4c7..bf80ac97a1 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py
@@ -34,7 +34,6 @@ from synapse.push.pusher import PusherFactory from synapse.replication.http.push import ReplicationRemovePusherRestServlet from synapse.types import JsonDict, RoomStreamToken, StrCollection from synapse.util.async_helpers import concurrently_execute -from synapse.util.threepids import canonicalise_email if TYPE_CHECKING: from synapse.server import HomeServer @@ -122,11 +121,7 @@ class PusherPool: """ if kind == "email": - email_owner = await self.store.get_user_id_by_threepid( - "email", canonicalise_email(pushkey) - ) - if email_owner != user_id: - raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) + raise SynapseError(400, "Threepids are not supported on this server", "M_UNSUPPORTED") time_now_msec = self.clock.time_msec() diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index c9cf838255..d500051714 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py
@@ -1,7 +1,7 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright (C) 2023 New Vector, Ltd +# Copyright (C) 2023-2024 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as @@ -23,6 +23,7 @@ from typing import TYPE_CHECKING from synapse.http.server import JsonResource from synapse.replication.http import ( account_data, + delayed_events, devices, federation, login, @@ -64,3 +65,4 @@ class ReplicationRestResource(JsonResource): login.register_servlets(hs, self) register.register_servlets(hs, self) devices.register_servlets(hs, self) + delayed_events.register_servlets(hs, self) diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 9aa8d90bfe..0002538680 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py
@@ -128,9 +128,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): # We reserve `instance_name` as a parameter to sending requests, so we # assert here that sub classes don't try and use the name. - assert ( - "instance_name" not in self.PATH_ARGS - ), "`instance_name` is a reserved parameter name" + assert "instance_name" not in self.PATH_ARGS, ( + "`instance_name` is a reserved parameter name" + ) assert ( "instance_name" not in signature(self.__class__._serialize_payload).parameters diff --git a/synapse/replication/http/delayed_events.py b/synapse/replication/http/delayed_events.py new file mode 100644
index 0000000000..229022070c --- /dev/null +++ b/synapse/replication/http/delayed_events.py
@@ -0,0 +1,62 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2024 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# + +import logging +from typing import TYPE_CHECKING, Dict, Optional, Tuple + +from twisted.web.server import Request + +from synapse.http.server import HttpServer +from synapse.replication.http._base import ReplicationEndpoint +from synapse.types import JsonDict, JsonMapping + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class ReplicationAddedDelayedEventRestServlet(ReplicationEndpoint): + """Handle a delayed event being added by another worker. + + Request format: + + POST /_synapse/replication/delayed_event_added/ + + {} + """ + + NAME = "added_delayed_event" + PATH_ARGS = () + CACHE = False + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.handler = hs.get_delayed_events_handler() + + @staticmethod + async def _serialize_payload(next_send_ts: int) -> JsonDict: # type: ignore[override] + return {"next_send_ts": next_send_ts} + + async def _handle_request( # type: ignore[override] + self, request: Request, content: JsonDict + ) -> Tuple[int, Dict[str, Optional[JsonMapping]]]: + self.handler.on_added(int(content["next_send_ts"])) + + return 200, {} + + +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + ReplicationAddedDelayedEventRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 9c537427df..940f418396 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py
@@ -119,7 +119,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): return payload - async def _handle_request(self, request: Request, content: JsonDict) -> Tuple[int, JsonDict]: # type: ignore[override] + async def _handle_request( # type: ignore[override] + self, request: Request, content: JsonDict + ) -> Tuple[int, JsonDict]: with Measure(self.clock, "repl_fed_send_events_parse"): room_id = content["room_id"] backfilled = content["backfilled"] diff --git a/synapse/replication/http/push.py b/synapse/replication/http/push.py
index de07e75b46..48e254cdb1 100644 --- a/synapse/replication/http/push.py +++ b/synapse/replication/http/push.py
@@ -48,7 +48,7 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint): """ - NAME = "add_user_account_data" + NAME = "remove_pusher" PATH_ARGS = ("user_id",) CACHE = False @@ -98,7 +98,9 @@ class ReplicationCopyPusherRestServlet(ReplicationEndpoint): self._store = hs.get_datastores().main @staticmethod - async def _serialize_payload(user_id: str, old_room_id: str, new_room_id: str) -> JsonDict: # type: ignore[override] + async def _serialize_payload( # type: ignore[override] + user_id: str, old_room_id: str, new_room_id: str + ) -> JsonDict: return {} async def _handle_request( # type: ignore[override] @@ -109,7 +111,6 @@ class ReplicationCopyPusherRestServlet(ReplicationEndpoint): old_room_id: str, new_room_id: str, ) -> Tuple[int, JsonDict]: - await self._store.copy_push_rules_from_room_to_room_for_user( old_room_id, new_room_id, user_id ) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 3dddbb70b4..0bd5478cd3 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py
@@ -18,8 +18,8 @@ # [This file includes modifications made by New Vector Limited] # # -"""A replication client for use by synapse workers. -""" +"""A replication client for use by synapse workers.""" + import logging from typing import TYPE_CHECKING, Dict, Iterable, Optional, Set, Tuple diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index b7a7e77597..6ab5356660 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py
@@ -23,6 +23,7 @@ The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS define which commands are allowed to be sent by which side. """ + import abc import logging from typing import List, Optional, Tuple, Type, TypeVar @@ -494,7 +495,7 @@ class LockReleasedCommand(Command): class NewActiveTaskCommand(_SimpleCommand): - """Sent to inform instance handling background tasks that a new active task is available to run. + """Sent to inform instance handling background tasks that a new task is ready to run. Format:: diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 72a42cb6cc..1fafbb48c3 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py
@@ -727,7 +727,7 @@ class ReplicationCommandHandler: ) -> None: """Called when get a new NEW_ACTIVE_TASK command.""" if self._task_scheduler: - self._task_scheduler.launch_task_by_id(cmd.data) + self._task_scheduler.on_new_task(cmd.data) def new_connection(self, connection: IReplicationConnection) -> None: """Called when we have a new connection.""" @@ -857,7 +857,7 @@ UpdateRow = TypeVar("UpdateRow") def _batch_updates( - updates: Iterable[Tuple[UpdateToken, UpdateRow]] + updates: Iterable[Tuple[UpdateToken, UpdateRow]], ) -> Iterator[Tuple[UpdateToken, List[UpdateRow]]]: """Collect stream updates with the same token together diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 4471cc8f0c..fb9c539122 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py
@@ -23,6 +23,7 @@ protocols. An explanation of this protocol is available in docs/tcp_replication.md """ + import fcntl import logging import struct diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index c0329378ac..d647a2b332 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py
@@ -18,8 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # -"""The server side of the replication stream. -""" +"""The server side of the replication stream.""" import logging import random @@ -307,7 +306,7 @@ class ReplicationStreamer: def _batch_updates( - updates: List[Tuple[Token, StreamRow]] + updates: List[Tuple[Token, StreamRow]], ) -> List[Tuple[Optional[Token], StreamRow]]: """Takes a list of updates of form [(token, row)] and sets the token to None for all rows where the next row has the same token. This is used to diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index d021904de7..ebf5964d29 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py
@@ -247,7 +247,7 @@ class _StreamFromIdGen(Stream): def current_token_without_instance( - current_token: Callable[[], int] + current_token: Callable[[], int], ) -> Callable[[str], int]: """Takes a current token callback function for a single writer stream that doesn't take an instance name parameter and wraps it in a function that diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index ea0803dfc2..05b55fb033 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py
@@ -200,9 +200,9 @@ class EventsStream(_StreamFromIdGen): # we rely on get_all_new_forward_event_rows strictly honouring the limit, so # that we know it is safe to just take upper_limit = event_rows[-1][0]. - assert ( - len(event_rows) <= target_row_count - ), "get_all_new_forward_event_rows did not honour row limit" + assert len(event_rows) <= target_row_count, ( + "get_all_new_forward_event_rows did not honour row limit" + ) # if we hit the limit on event_updates, there's no point in going beyond the # last stream_id in the batch for the other sources. diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index c5cdc36955..00f108de08 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py
@@ -2,7 +2,7 @@ # This file is licensed under the Affero General Public License (AGPL) version 3. # # Copyright 2014-2016 OpenMarket Ltd -# Copyright (C) 2023 New Vector, Ltd +# Copyright (C) 2023-2024 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as @@ -29,8 +29,9 @@ from synapse.rest.client import ( account_validity, appservice_ping, auth, - auth_issuer, + auth_metadata, capabilities, + delayed_events, devices, directory, events, @@ -81,6 +82,7 @@ CLIENT_SERVLET_FUNCTIONS: Tuple[RegisterServletsFunc, ...] = ( room.register_deprecated_servlets, events.register_servlets, room.register_servlets, + delayed_events.register_servlets, login.register_servlets, profile.register_servlets, presence.register_servlets, @@ -119,7 +121,7 @@ CLIENT_SERVLET_FUNCTIONS: Tuple[RegisterServletsFunc, ...] = ( mutual_rooms.register_servlets, login_token_request.register_servlets, rendezvous.register_servlets, - auth_issuer.register_servlets, + auth_metadata.register_servlets, ) SERVLET_GROUPS: Dict[str, Iterable[RegisterServletsFunc]] = { @@ -185,7 +187,6 @@ class ClientRestResource(JsonResource): mutual_rooms.register_servlets, login_token_request.register_servlets, rendezvous.register_servlets, - auth_issuer.register_servlets, ]: continue diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index cdaee17451..b37bf3429b 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py
@@ -39,7 +39,7 @@ from typing import TYPE_CHECKING, Optional, Tuple from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.handlers.pagination import PURGE_HISTORY_ACTION_NAME -from synapse.http.server import HttpServer, JsonResource +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseRequest from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin @@ -51,6 +51,7 @@ from synapse.rest.admin.background_updates import ( from synapse.rest.admin.devices import ( DeleteDevicesRestServlet, DeviceRestServlet, + DevicesGetRestServlet, DevicesRestServlet, ) from synapse.rest.admin.event_reports import ( @@ -86,6 +87,7 @@ from synapse.rest.admin.rooms import ( RoomStateRestServlet, RoomTimestampToEventRestServlet, ) +from synapse.rest.admin.scheduled_tasks import ScheduledTasksRestServlet from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet from synapse.rest.admin.statistics import ( LargestRoomsStatistics, @@ -98,13 +100,16 @@ from synapse.rest.admin.users import ( DeactivateAccountRestServlet, PushersRestServlet, RateLimitRestServlet, + RedactUser, + RedactUserStatus, ResetPasswordRestServlet, SearchUsersRestServlet, ShadowBanRestServlet, SuspendAccountRestServlet, UserAdminServlet, UserByExternalId, - UserByThreePid, + UserInvitesCount, + UserJoinedRoomCount, UserMembershipRestServlet, UserRegisterServlet, UserReplaceMasterCrossSigningKeyRestServlet, @@ -201,8 +206,7 @@ class PurgeHistoryRestServlet(RestServlet): (stream, topo, _event_id) = r token = "t%d-%d" % (topo, stream) logger.info( - "[purge] purging up to token %s (received_ts %i => " - "stream_ordering %i)", + "[purge] purging up to token %s (received_ts %i => stream_ordering %i)", token, ts, stream_ordering, @@ -259,27 +263,24 @@ class PurgeHistoryStatusRestServlet(RestServlet): 
######################################################################################## -class AdminRestResource(JsonResource): - """The REST resource which gets mounted at /_synapse/admin""" - - def __init__(self, hs: "HomeServer"): - JsonResource.__init__(self, hs, canonical_json=False) - register_servlets(hs, self) - - def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: """ Register all the admin servlets. """ - # Admin servlets aren't registered on workers. + RoomRestServlet(hs).register(http_server) + + # Admin servlets below may not work on workers. if hs.config.worker.worker_app is not None: + # Some admin servlets can be mounted on workers when MSC3861 is enabled. + if hs.config.experimental.msc3861.enabled: + register_servlets_for_msc3861_delegation(hs, http_server) + return register_servlets_for_client_rest_resource(hs, http_server) BlockRoomRestServlet(hs).register(http_server) ListRoomRestServlet(hs).register(http_server) RoomStateRestServlet(hs).register(http_server) - RoomRestServlet(hs).register(http_server) RoomRestV2Servlet(hs).register(http_server) RoomMembersRestServlet(hs).register(http_server) DeleteRoomStatusByDeleteIdRestServlet(hs).register(http_server) @@ -318,7 +319,10 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RoomTimestampToEventRestServlet(hs).register(http_server) UserReplaceMasterCrossSigningKeyRestServlet(hs).register(http_server) UserByExternalId(hs).register(http_server) - UserByThreePid(hs).register(http_server) + RedactUser(hs).register(http_server) + RedactUserStatus(hs).register(http_server) + UserInvitesCount(hs).register(http_server) + UserJoinedRoomCount(hs).register(http_server) DeviceRestServlet(hs).register(http_server) DevicesRestServlet(hs).register(http_server) @@ -328,8 +332,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: BackgroundUpdateRestServlet(hs).register(http_server) 
BackgroundUpdateStartJobRestServlet(hs).register(http_server) ExperimentalFeaturesRestServlet(hs).register(http_server) - if hs.config.experimental.msc3823_account_suspension: - SuspendAccountRestServlet(hs).register(http_server) + SuspendAccountRestServlet(hs).register(http_server) + ScheduledTasksRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource( @@ -357,4 +361,16 @@ def register_servlets_for_client_rest_resource( ListMediaInRoom(hs).register(http_server) # don't add more things here: new servlets should only be exposed on - # /_synapse/admin so should not go here. Instead register them in AdminRestResource. + # /_synapse/admin so should not go here. Instead register them in register_servlets. + + +def register_servlets_for_msc3861_delegation( + hs: "HomeServer", http_server: HttpServer +) -> None: + """Register servlets needed by MAS when MSC3861 is enabled""" + assert hs.config.experimental.msc3861.enabled + + UserRestServletV2(hs).register(http_server) + UsernameAvailableRestServlet(hs).register(http_server) + UserReplaceMasterCrossSigningKeyRestServlet(hs).register(http_server) + DevicesGetRestServlet(hs).register(http_server) diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
index 449b066923..09baf8ce21 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py
@@ -113,18 +113,19 @@ class DeviceRestServlet(RestServlet): return HTTPStatus.OK, {} -class DevicesRestServlet(RestServlet): +class DevicesGetRestServlet(RestServlet): """ Retrieve the given user's devices + + This can be mounted on workers as it is read-only, as opposed + to `DevicesRestServlet`. """ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2") def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - handler = hs.get_device_handler() - assert isinstance(handler, DeviceHandler) - self.device_handler = handler + self.device_worker_handler = hs.get_device_handler() self.store = hs.get_datastores().main self.is_mine = hs.is_mine @@ -141,9 +142,35 @@ class DevicesRestServlet(RestServlet): if u is None: raise NotFoundError("Unknown user") - devices = await self.device_handler.get_devices_by_user(target_user.to_string()) + devices = await self.device_worker_handler.get_devices_by_user( + target_user.to_string() + ) + + # mark the dehydrated device by adding a "dehydrated" flag + dehydrated_device_info = await self.device_worker_handler.get_dehydrated_device( + target_user.to_string() + ) + if dehydrated_device_info: + dehydrated_device_id = dehydrated_device_info[0] + for device in devices: + is_dehydrated = device["device_id"] == dehydrated_device_id + device["dehydrated"] = is_dehydrated + return HTTPStatus.OK, {"devices": devices, "total": len(devices)} + +class DevicesRestServlet(DevicesGetRestServlet): + """ + Retrieve the given user's devices + """ + + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2") + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + assert isinstance(self.device_worker_handler, DeviceHandler) + self.device_handler = self.device_worker_handler + async def on_POST( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py
index 9fb68bfa46..ff1abc0697 100644 --- a/synapse/rest/admin/event_reports.py +++ b/synapse/rest/admin/event_reports.py
@@ -50,8 +50,10 @@ class EventReportsRestServlet(RestServlet): The parameters `from` and `limit` are required only for pagination. By default, a `limit` of 100 is used. The parameter `dir` can be used to define the order of results. - The parameter `user_id` can be used to filter by user id. - The parameter `room_id` can be used to filter by room id. + The `user_id` query parameter filters by the user ID of the reporter of the event. + The `room_id` query parameter filters by room id. + The `event_sender_user_id` query parameter can be used to filter by the user id + of the sender of the reported event. Returns: A list of reported events and an integer representing the total number of reported events that exist given this query @@ -71,6 +73,7 @@ class EventReportsRestServlet(RestServlet): direction = parse_enum(request, "dir", Direction, Direction.BACKWARDS) user_id = parse_string(request, "user_id") room_id = parse_string(request, "room_id") + event_sender_user_id = parse_string(request, "event_sender_user_id") if start < 0: raise SynapseError( @@ -87,7 +90,7 @@ class EventReportsRestServlet(RestServlet): ) event_reports, total = await self._store.get_event_reports_paginate( - start, limit, direction, user_id, room_id + start, limit, direction, user_id, room_id, event_sender_user_id ) ret = {"event_reports": event_reports, "total": total} if (start + limit) < total: diff --git a/synapse/rest/admin/experimental_features.py b/synapse/rest/admin/experimental_features.py
index d7913896d9..afb71f4a0f 100644 --- a/synapse/rest/admin/experimental_features.py +++ b/synapse/rest/admin/experimental_features.py
@@ -43,12 +43,15 @@ class ExperimentalFeature(str, Enum): MSC3881 = "msc3881" MSC3575 = "msc3575" + MSC4222 = "msc4222" def is_globally_enabled(self, config: "HomeServerConfig") -> bool: if self is ExperimentalFeature.MSC3881: return config.experimental.msc3881_enabled if self is ExperimentalFeature.MSC3575: return config.experimental.msc3575_enabled + if self is ExperimentalFeature.MSC4222: + return config.experimental.msc4222_enabled assert_never(self) diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py
index 0867f7a51c..bec2331590 100644 --- a/synapse/rest/admin/registration_tokens.py +++ b/synapse/rest/admin/registration_tokens.py
@@ -181,8 +181,7 @@ class NewRegistrationTokenRestServlet(RestServlet): uses_allowed = body.get("uses_allowed", None) if not ( - uses_allowed is None - or (type(uses_allowed) is int and uses_allowed >= 0) # noqa: E721 + uses_allowed is None or (type(uses_allowed) is int and uses_allowed >= 0) # noqa: E721 ): raise SynapseError( HTTPStatus.BAD_REQUEST, diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 01f9de9ffa..adac1f0362 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py
@@ -23,6 +23,7 @@ from http import HTTPStatus from typing import TYPE_CHECKING, List, Optional, Tuple, cast import attr +from immutabledict import immutabledict from synapse.api.constants import Direction, EventTypes, JoinRules, Membership from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError @@ -149,6 +150,7 @@ class RoomRestV2Servlet(RestServlet): def _convert_delete_task_to_response(task: ScheduledTask) -> JsonDict: return { "delete_id": task.id, + "room_id": task.resource_id, "status": task.status, "shutdown_room": task.result, } @@ -249,6 +251,10 @@ class ListRoomRestServlet(RestServlet): direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS) reverse_order = True if direction == Direction.BACKWARDS else False + emma_include_tombstone = parse_boolean( + request, "emma_include_tombstone", default=False + ) + # Return list of rooms according to parameters rooms, total_rooms = await self.store.get_rooms_paginate( start, @@ -258,6 +264,7 @@ class ListRoomRestServlet(RestServlet): search_term, public_rooms, empty_rooms, + emma_include_tombstone = emma_include_tombstone ) response = { @@ -463,7 +470,18 @@ class RoomStateRestServlet(RestServlet): if not room: raise NotFoundError("Room not found") - event_ids = await self._storage_controllers.state.get_current_state_ids(room_id) + state_filter = None + type = parse_string(request, "type") + + if type: + state_filter = StateFilter( + types=immutabledict({type: None}), + include_others=False, + ) + + event_ids = await self._storage_controllers.state.get_current_state_ids( + room_id, state_filter + ) events = await self.store.get_events(event_ids.values()) now = self.clock.time_msec() room_state = await self._event_serializer.serialize_events(events.values(), now) diff --git a/synapse/rest/admin/scheduled_tasks.py b/synapse/rest/admin/scheduled_tasks.py new file mode 100644
index 0000000000..2ae13021b9 --- /dev/null +++ b/synapse/rest/admin/scheduled_tasks.py
@@ -0,0 +1,70 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# +# +# +from typing import TYPE_CHECKING, Tuple + +from synapse.http.servlet import RestServlet, parse_integer, parse_string +from synapse.http.site import SynapseRequest +from synapse.rest.admin import admin_patterns, assert_requester_is_admin +from synapse.types import JsonDict, TaskStatus + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +class ScheduledTasksRestServlet(RestServlet): + """Get a list of scheduled tasks and their statuses + optionally filtered by action name, resource id, status, and max timestamp + """ + + PATTERNS = admin_patterns("/scheduled_tasks$") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastores().main + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + # extract query params + action_name = parse_string(request, "action_name") + resource_id = parse_string(request, "resource_id") + status = parse_string(request, "job_status") + max_timestamp = parse_integer(request, "max_timestamp") + + actions = [action_name] if action_name else None + statuses = [TaskStatus(status)] if status else None + + tasks = await self._store.get_scheduled_tasks( + actions=actions, + resource_id=resource_id, + statuses=statuses, + max_timestamp=max_timestamp, + ) + + json_tasks = [] + for task in tasks: + result_task = { + "id": task.id, + "action": task.action, + "status": task.status, + 
"timestamp_ms": task.timestamp, + "resource_id": task.resource_id, + "result": task.result, + "error": task.error, + } + json_tasks.append(result_task) + + return 200, {"scheduled_tasks": json_tasks} diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index ad515bd5a3..7671e020e0 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py
@@ -27,8 +27,8 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import attr -from synapse._pydantic_compat import HAS_PYDANTIC_V2 -from synapse.api.constants import Direction, UserTypes +from synapse._pydantic_compat import StrictBool, StrictInt, StrictStr +from synapse.api.constants import Direction from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.servlet import ( RestServlet, @@ -50,17 +50,12 @@ from synapse.rest.admin._base import ( from synapse.rest.client._base import client_patterns from synapse.storage.databases.main.registration import ExternalIDReuseException from synapse.storage.databases.main.stats import UserSortOrder -from synapse.types import JsonDict, JsonMapping, UserID +from synapse.types import JsonDict, JsonMapping, TaskStatus, UserID from synapse.types.rest import RequestBodyModel if TYPE_CHECKING: from synapse.server import HomeServer -if TYPE_CHECKING or HAS_PYDANTIC_V2: - from pydantic.v1 import StrictBool -else: - from pydantic import StrictBool - logger = logging.getLogger(__name__) @@ -235,6 +230,7 @@ class UserRestServletV2(RestServlet): self.registration_handler = hs.get_registration_handler() self.pusher_pool = hs.get_pusherpool() self._msc3866_enabled = hs.config.experimental.msc3866.enabled + self._all_user_types = hs.config.user_types.all_user_types async def on_GET( self, request: SynapseRequest, user_id: str @@ -269,12 +265,6 @@ class UserRestServletV2(RestServlet): user = await self.admin_handler.get_user(target_user) user_id = target_user.to_string() - # check for required parameters for each threepid - threepids = body.get("threepids") - if threepids is not None: - for threepid in threepids: - assert_params_in_dict(threepid, ["medium", "address"]) - # check for required parameters for each external_id external_ids = body.get("external_ids") if external_ids is not None: @@ -282,7 +272,7 @@ class UserRestServletV2(RestServlet): assert_params_in_dict(external_id, 
["auth_provider", "external_id"]) user_type = body.get("user_type", None) - if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: + if user_type is not None and user_type not in self._all_user_types: raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid user type") set_admin_to = body.get("admin", False) @@ -338,51 +328,12 @@ class UserRestServletV2(RestServlet): for external_id in external_ids ] - # convert List[Dict[str, str]] into Set[Tuple[str, str]] - if threepids is not None: - new_threepids = { - (threepid["medium"], threepid["address"]) for threepid in threepids - } - if user: # modify user if "displayname" in body: await self.profile_handler.set_displayname( target_user, requester, body["displayname"], True ) - if threepids is not None: - # get changed threepids (added and removed) - cur_threepids = { - (threepid.medium, threepid.address) - for threepid in await self.store.user_get_threepids(user_id) - } - add_threepids = new_threepids - cur_threepids - del_threepids = cur_threepids - new_threepids - - # remove old threepids - for medium, address in del_threepids: - try: - # Attempt to remove any known bindings of this third-party ID - # and user ID from identity servers. - await self.hs.get_identity_handler().try_unbind_threepid( - user_id, medium, address, id_server=None - ) - except Exception: - logger.exception("Failed to remove threepids") - raise SynapseError(500, "Failed to remove threepids") - - # Delete the local association of this user ID and third-party ID. 
- await self.auth_handler.delete_local_threepid( - user_id, medium, address - ) - - # add new threepids - current_time = self.hs.get_clock().time_msec() - for medium, address in add_threepids: - await self.auth_handler.add_threepid( - user_id, medium, address, current_time - ) - if external_ids is not None: try: await self.store.replace_user_external_id( @@ -467,28 +418,6 @@ class UserRestServletV2(RestServlet): approved=new_user_approved, ) - if threepids is not None: - current_time = self.hs.get_clock().time_msec() - for medium, address in new_threepids: - await self.auth_handler.add_threepid( - user_id, medium, address, current_time - ) - if ( - self.hs.config.email.email_enable_notifs - and self.hs.config.email.email_notif_for_new_users - and medium == "email" - ): - await self.pusher_pool.add_or_update_pusher( - user_id=user_id, - kind="email", - app_id="m.email", - app_display_name="Email Notifications", - device_display_name=address, - pushkey=address, - lang=None, - data={}, - ) - if external_ids is not None: try: for auth_provider, external_id in new_external_ids: @@ -529,6 +458,7 @@ class UserRegisterServlet(RestServlet): self.reactor = hs.get_reactor() self.nonces: Dict[str, int] = {} self.hs = hs + self._all_user_types = hs.config.user_types.all_user_types def _clear_old_nonces(self) -> None: """ @@ -610,7 +540,7 @@ class UserRegisterServlet(RestServlet): user_type = body.get("user_type", None) displayname = body.get("displayname", None) - if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: + if user_type is not None and user_type not in self._all_user_types: raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid user type") if "mac" not in body: @@ -988,7 +918,7 @@ class UserAdminServlet(RestServlet): class UserMembershipRestServlet(RestServlet): """ - Get room list of an user. + Get list of joined room ID's for a user. 
""" PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/joined_rooms$") @@ -1004,8 +934,9 @@ class UserMembershipRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) room_ids = await self.store.get_rooms_for_user(user_id) - ret = {"joined_rooms": list(room_ids), "total": len(room_ids)} - return HTTPStatus.OK, ret + rooms_response = {"joined_rooms": list(room_ids), "total": len(room_ids)} + + return HTTPStatus.OK, rooms_response class PushersRestServlet(RestServlet): @@ -1387,26 +1318,144 @@ class UserByExternalId(RestServlet): return HTTPStatus.OK, {"user_id": user_id} -class UserByThreePid(RestServlet): - """Find a user based on 3PID of a particular medium""" +class RedactUser(RestServlet): + """ + Redact all the events of a given user in the given rooms or if empty dict is provided + then all events in all rooms user is member of. Kicks off a background process and + returns an id that can be used to check on the progress of the redaction progress + """ - PATTERNS = admin_patterns("/threepid/(?P<medium>[^/]*)/users/(?P<address>[^/]*)") + PATTERNS = admin_patterns("/user/(?P<user_id>[^/]*)/redact") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() self._store = hs.get_datastores().main + self.admin_handler = hs.get_admin_handler() + + class PostBody(RequestBodyModel): + rooms: List[StrictStr] + reason: Optional[StrictStr] + limit: Optional[StrictInt] + + async def on_POST( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: + requester = await self._auth.get_user_by_req(request) + await assert_user_is_admin(self._auth, requester) + + # parse provided user id to check that it is valid + UserID.from_string(user_id) + + body = parse_and_validate_json_object_from_request(request, self.PostBody) + + limit = body.limit + if limit and limit <= 0: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "If limit is provided it must be a non-negative integer greater than 0.", + ) + + rooms = body.rooms + if not 
rooms: + current_rooms = list(await self._store.get_rooms_for_user(user_id)) + banned_rooms = list( + await self._store.get_rooms_user_currently_banned_from(user_id) + ) + rooms = current_rooms + banned_rooms + + redact_id = await self.admin_handler.start_redact_events( + user_id, rooms, requester.serialize(), body.reason, limit + ) + + return HTTPStatus.OK, {"redact_id": redact_id} + + +class RedactUserStatus(RestServlet): + """ + Check on the progress of the redaction request represented by the provided ID, returning + the status of the process and a dict of events that were unable to be redacted, if any + """ + + PATTERNS = admin_patterns("/user/redact_status/(?P<redact_id>[^/]*)$") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self.admin_handler = hs.get_admin_handler() async def on_GET( - self, - request: SynapseRequest, - medium: str, - address: str, + self, request: SynapseRequest, redact_id: str ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self._auth, request) - user_id = await self._store.get_user_id_by_threepid(medium, address) + task = await self.admin_handler.get_redact_task(redact_id) + + if task: + if task.status == TaskStatus.ACTIVE: + return HTTPStatus.OK, {"status": TaskStatus.ACTIVE} + elif task.status == TaskStatus.COMPLETE: + assert task.result is not None + failed_redactions = task.result.get("failed_redactions") + return HTTPStatus.OK, { + "status": TaskStatus.COMPLETE, + "failed_redactions": failed_redactions if failed_redactions else {}, + } + elif task.status == TaskStatus.SCHEDULED: + return HTTPStatus.OK, {"status": TaskStatus.SCHEDULED} + else: + return HTTPStatus.OK, { + "status": TaskStatus.FAILED, + "error": ( + task.error + if task.error + else "Unknown error, please check the logs for more information." 
+ ), + } + else: + raise NotFoundError("redact id '%s' not found" % redact_id) - if user_id is None: - raise NotFoundError("User not found") - return HTTPStatus.OK, {"user_id": user_id} +class UserInvitesCount(RestServlet): + """ + Return the count of invites that the user has sent after the given timestamp + """ + + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/sent_invite_count") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self.store = hs.get_datastores().main + + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + from_ts = parse_integer(request, "from_ts", required=True) + + sent_invite_count = await self.store.get_sent_invite_count_by_user( + user_id, from_ts + ) + + return HTTPStatus.OK, {"invite_count": sent_invite_count} + + +class UserJoinedRoomCount(RestServlet): + """ + Return the count of rooms that the user has joined at or after the given timestamp, even + if they have subsequently left/been banned from those rooms. + """ + + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/cumulative_joined_room_count") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self.store = hs.get_datastores().main + + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + from_ts = parse_integer(request, "from_ts", required=True) + + joined_rooms = await self.store.get_rooms_for_user_by_date(user_id, from_ts) + + return HTTPStatus.OK, {"cumulative_joined_room_count": len(joined_rooms)} diff --git a/synapse/rest/client/_base.py b/synapse/rest/client/_base.py
index 93dec6375a..6cf37869d8 100644 --- a/synapse/rest/client/_base.py +++ b/synapse/rest/client/_base.py
@@ -19,8 +19,8 @@ # # -"""This module contains base REST classes for constructing client v1 servlets. -""" +"""This module contains base REST classes for constructing client v1 servlets.""" + import logging import re from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar, cast diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index 8daa449f9e..455ddda484 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py
@@ -21,28 +21,20 @@ # import logging import random -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, List, Literal, Optional, Tuple from urllib.parse import urlparse -from synapse._pydantic_compat import HAS_PYDANTIC_V2 - -if TYPE_CHECKING or HAS_PYDANTIC_V2: - from pydantic.v1 import StrictBool, StrictStr, constr -else: - from pydantic import StrictBool, StrictStr, constr - import attr -from typing_extensions import Literal from twisted.web.server import Request +from synapse._pydantic_compat import StrictBool, StrictStr, constr from synapse.api.constants import LoginType from synapse.api.errors import ( Codes, InteractiveAuthIncompleteError, NotFoundError, SynapseError, - ThreepidValidationError, ) from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http.server import HttpServer, finish_request, respond_with_html @@ -54,19 +46,13 @@ from synapse.http.servlet import ( parse_string, ) from synapse.http.site import SynapseRequest -from synapse.metrics import threepid_send_requests -from synapse.push.mailer import Mailer from synapse.types import JsonDict from synapse.types.rest import RequestBodyModel from synapse.types.rest.client import ( AuthenticationData, - ClientSecretStr, - EmailRequestTokenBody, - MsisdnRequestTokenBody, + ClientSecretStr ) -from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.stringutils import assert_valid_client_secret, random_string -from synapse.util.threepids import check_3pid_allowed, validate_email from ._base import client_patterns, interactive_auth_handler @@ -77,80 +63,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class EmailPasswordRequestTokenRestServlet(RestServlet): - PATTERNS = client_patterns("/account/password/email/requestToken$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.hs = hs - self.datastore = hs.get_datastores().main - self.config = hs.config - self.identity_handler = 
hs.get_identity_handler() - - if self.config.email.can_verify_email: - self.mailer = Mailer( - hs=self.hs, - app_name=self.config.email.email_app_name, - template_html=self.config.email.email_password_reset_template_html, - template_text=self.config.email.email_password_reset_template_text, - ) - - async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if not self.config.email.can_verify_email: - logger.warning( - "User password resets have been disabled due to lack of email config" - ) - raise SynapseError( - 400, "Email-based password resets have been disabled on this server" - ) - - body = parse_and_validate_json_object_from_request( - request, EmailRequestTokenBody - ) - - if body.next_link: - # Raise if the provided next_link value isn't valid - assert_valid_next_link(self.hs, body.next_link) - - await self.identity_handler.ratelimit_request_token_requests( - request, "email", body.email - ) - - # The email will be sent to the stored address. - # This avoids a potential account hijack by requesting a password reset to - # an email address which is controlled by the attacker but which, after - # canonicalisation, matches the one in our database. - existing_user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( - "email", body.email - ) - - if existing_user_id is None: - if self.config.server.request_token_inhibit_3pid_errors: - # Make the client think the operation succeeded. See the rationale in the - # comments for request_token_inhibit_3pid_errors. - # Also wait for some random amount of time between 100ms and 1s to make it - # look like we did something. 
- await self.hs.get_clock().sleep(random.randint(1, 10) / 10) - return 200, {"sid": random_string(16)} - - raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) - - # Send password reset emails from Synapse - sid = await self.identity_handler.send_threepid_validation( - body.email, - body.client_secret, - body.send_attempt, - self.mailer.send_password_reset_mail, - body.next_link, - ) - threepid_send_requests.labels(type="email", reason="password_reset").observe( - body.send_attempt - ) - - # Wrap the session id in a JSON object - return 200, {"sid": sid} - - class PasswordRestServlet(RestServlet): PATTERNS = client_patterns("/account/password$") @@ -211,30 +123,8 @@ class PasswordRestServlet(RestServlet): "modify your account password", ) - if LoginType.EMAIL_IDENTITY in result: - threepid = result[LoginType.EMAIL_IDENTITY] - if "medium" not in threepid or "address" not in threepid: - raise SynapseError(500, "Malformed threepid") - if threepid["medium"] == "email": - # For emails, canonicalise the address. - # We store all email addresses canonicalised in the DB. - # (See add_threepid in synapse/handlers/auth.py) - try: - threepid["address"] = validate_email(threepid["address"]) - except ValueError as e: - raise SynapseError(400, str(e)) - # if using email, we must know about the email they're authing with! - threepid_user_id = await self.datastore.get_user_id_by_threepid( - threepid["medium"], threepid["address"] - ) - if not threepid_user_id: - raise SynapseError( - 404, "Email address not found", Codes.NOT_FOUND - ) - user_id = threepid_user_id - else: - logger.error("Auth succeeded but no known type! %r", result.keys()) - raise SynapseError(500, "", Codes.UNKNOWN) + logger.error("Auth succeeded but no known type (hint: 3PID auth was removed)! 
%r", result.keys()) + raise SynapseError(500, "", Codes.UNKNOWN) except InteractiveAuthIncompleteError as e: # The user needs to provide more steps to complete auth, but @@ -326,486 +216,6 @@ class DeactivateAccountRestServlet(RestServlet): return 200, {"id_server_unbind_result": id_server_unbind_result} -class EmailThreepidRequestTokenRestServlet(RestServlet): - PATTERNS = client_patterns("/account/3pid/email/requestToken$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.hs = hs - self.config = hs.config - self.identity_handler = hs.get_identity_handler() - self.store = self.hs.get_datastores().main - - if self.config.email.can_verify_email: - self.mailer = Mailer( - hs=self.hs, - app_name=self.config.email.email_app_name, - template_html=self.config.email.email_add_threepid_template_html, - template_text=self.config.email.email_add_threepid_template_text, - ) - - async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if not self.hs.config.registration.enable_3pid_changes: - raise SynapseError( - 400, "3PID changes are disabled on this server", Codes.FORBIDDEN - ) - - if not self.config.email.can_verify_email: - logger.warning( - "Adding emails have been disabled due to lack of an email config" - ) - raise SynapseError( - 400, - "Adding an email to your account is disabled on this server", - ) - - body = parse_and_validate_json_object_from_request( - request, EmailRequestTokenBody - ) - - if not await check_3pid_allowed(self.hs, "email", body.email): - raise SynapseError( - 403, - "Your email domain is not authorized on this server", - Codes.THREEPID_DENIED, - ) - - await self.identity_handler.ratelimit_request_token_requests( - request, "email", body.email - ) - - if body.next_link: - # Raise if the provided next_link value isn't valid - assert_valid_next_link(self.hs, body.next_link) - - existing_user_id = await self.store.get_user_id_by_threepid("email", body.email) - - if existing_user_id is not None: - if 
self.config.server.request_token_inhibit_3pid_errors: - # Make the client think the operation succeeded. See the rationale in the - # comments for request_token_inhibit_3pid_errors. - # Also wait for some random amount of time between 100ms and 1s to make it - # look like we did something. - await self.hs.get_clock().sleep(random.randint(1, 10) / 10) - return 200, {"sid": random_string(16)} - - raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) - - # Send threepid validation emails from Synapse - sid = await self.identity_handler.send_threepid_validation( - body.email, - body.client_secret, - body.send_attempt, - self.mailer.send_add_threepid_mail, - body.next_link, - ) - - threepid_send_requests.labels(type="email", reason="add_threepid").observe( - body.send_attempt - ) - - # Wrap the session id in a JSON object - return 200, {"sid": sid} - - -class MsisdnThreepidRequestTokenRestServlet(RestServlet): - PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$") - - def __init__(self, hs: "HomeServer"): - self.hs = hs - super().__init__() - self.store = self.hs.get_datastores().main - self.identity_handler = hs.get_identity_handler() - - async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - body = parse_and_validate_json_object_from_request( - request, MsisdnRequestTokenBody - ) - msisdn = phone_number_to_msisdn(body.country, body.phone_number) - logger.info("Request #%s to verify ownership of %s", body.send_attempt, msisdn) - - if not await check_3pid_allowed(self.hs, "msisdn", msisdn): - raise SynapseError( - 403, - # TODO: is this error message accurate? 
Looks like we've only rejected - # this phone number, not necessarily all phone numbers - "Account phone numbers are not authorized on this server", - Codes.THREEPID_DENIED, - ) - - await self.identity_handler.ratelimit_request_token_requests( - request, "msisdn", msisdn - ) - - if body.next_link: - # Raise if the provided next_link value isn't valid - assert_valid_next_link(self.hs, body.next_link) - - existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn) - - if existing_user_id is not None: - if self.hs.config.server.request_token_inhibit_3pid_errors: - # Make the client think the operation succeeded. See the rationale in the - # comments for request_token_inhibit_3pid_errors. - # Also wait for some random amount of time between 100ms and 1s to make it - # look like we did something. - await self.hs.get_clock().sleep(random.randint(1, 10) / 10) - return 200, {"sid": random_string(16)} - - logger.info("MSISDN %s is already in use by %s", msisdn, existing_user_id) - raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) - - if not self.hs.config.registration.account_threepid_delegate_msisdn: - logger.warning( - "No upstream msisdn account_threepid_delegate configured on the server to " - "handle this request" - ) - raise SynapseError( - 400, - "Adding phone numbers to user account is not supported by this homeserver", - ) - - ret = await self.identity_handler.requestMsisdnToken( - self.hs.config.registration.account_threepid_delegate_msisdn, - body.country, - body.phone_number, - body.client_secret, - body.send_attempt, - body.next_link, - ) - - threepid_send_requests.labels(type="msisdn", reason="add_threepid").observe( - body.send_attempt - ) - logger.info("MSISDN %s: got response from identity server: %s", msisdn, ret) - - return 200, ret - - -class AddThreepidEmailSubmitTokenServlet(RestServlet): - """Handles 3PID validation token submission for adding an email to a user's account""" - - PATTERNS = client_patterns( - 
"/add_threepid/email/submit_token$", releases=(), unstable=True - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.config = hs.config - self.clock = hs.get_clock() - self.store = hs.get_datastores().main - if self.config.email.can_verify_email: - self._failure_email_template = ( - self.config.email.email_add_threepid_template_failure_html - ) - - async def on_GET(self, request: Request) -> None: - if not self.config.email.can_verify_email: - logger.warning( - "Adding emails have been disabled due to lack of an email config" - ) - raise SynapseError( - 400, "Adding an email to your account is disabled on this server" - ) - - sid = parse_string(request, "sid", required=True) - token = parse_string(request, "token", required=True) - client_secret = parse_string(request, "client_secret", required=True) - assert_valid_client_secret(client_secret) - - # Attempt to validate a 3PID session - try: - # Mark the session as valid - next_link = await self.store.validate_threepid_session( - sid, client_secret, token, self.clock.time_msec() - ) - - # Perform a 302 redirect if next_link is set - if next_link: - request.setResponseCode(302) - request.setHeader("Location", next_link) - finish_request(request) - return None - - # Otherwise show the success template - html = self.config.email.email_add_threepid_template_success_html_content - status_code = 200 - except ThreepidValidationError as e: - status_code = e.code - - # Show a failure page with a reason - template_vars = {"failure_reason": e.msg} - html = self._failure_email_template.render(**template_vars) - - respond_with_html(request, status_code, html) - - -class AddThreepidMsisdnSubmitTokenServlet(RestServlet): - """Handles 3PID validation token submission for adding a phone number to a user's - account - """ - - PATTERNS = client_patterns( - "/add_threepid/msisdn/submit_token$", releases=(), unstable=True - ) - - class PostBody(RequestBodyModel): - client_secret: ClientSecretStr - sid: StrictStr - 
token: StrictStr - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.config = hs.config - self.clock = hs.get_clock() - self.store = hs.get_datastores().main - self.identity_handler = hs.get_identity_handler() - - async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: - if not self.config.registration.account_threepid_delegate_msisdn: - raise SynapseError( - 400, - "This homeserver is not validating phone numbers. Use an identity server " - "instead.", - ) - - body = parse_and_validate_json_object_from_request(request, self.PostBody) - - # Proxy submit_token request to msisdn threepid delegate - response = await self.identity_handler.proxy_msisdn_submit_token( - self.config.registration.account_threepid_delegate_msisdn, - body.client_secret, - body.sid, - body.token, - ) - return 200, response - - -class ThreepidRestServlet(RestServlet): - PATTERNS = client_patterns("/account/3pid$") - # This is used as a proxy for all the 3pid endpoints. - - CATEGORY = "Client API requests" - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.hs = hs - self.identity_handler = hs.get_identity_handler() - self.auth = hs.get_auth() - self.auth_handler = hs.get_auth_handler() - self.datastore = self.hs.get_datastores().main - - async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - - threepids = await self.datastore.user_get_threepids(requester.user.to_string()) - - return 200, {"threepids": [attr.asdict(t) for t in threepids]} - - # NOTE(dmr): I have chosen not to use Pydantic to parse this request's body, because - # the endpoint is deprecated. (If you really want to, you could do this by reusing - # ThreePidBindRestServelet.PostBody with an `alias_generator` to handle - # `threePidCreds` versus `three_pid_creds`. 
- async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if self.hs.config.experimental.msc3861.enabled: - raise NotFoundError(errcode=Codes.UNRECOGNIZED) - - if not self.hs.config.registration.enable_3pid_changes: - raise SynapseError( - 400, "3PID changes are disabled on this server", Codes.FORBIDDEN - ) - - requester = await self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - body = parse_json_object_from_request(request) - - threepid_creds = body.get("threePidCreds") or body.get("three_pid_creds") - if threepid_creds is None: - raise SynapseError( - 400, "Missing param three_pid_creds", Codes.MISSING_PARAM - ) - assert_params_in_dict(threepid_creds, ["client_secret", "sid"]) - - sid = threepid_creds["sid"] - client_secret = threepid_creds["client_secret"] - assert_valid_client_secret(client_secret) - - validation_session = await self.identity_handler.validate_threepid_session( - client_secret, sid - ) - if validation_session: - await self.auth_handler.add_threepid( - user_id, - validation_session["medium"], - validation_session["address"], - validation_session["validated_at"], - ) - return 200, {} - - raise SynapseError( - 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED - ) - - -class ThreepidAddRestServlet(RestServlet): - PATTERNS = client_patterns("/account/3pid/add$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.hs = hs - self.identity_handler = hs.get_identity_handler() - self.auth = hs.get_auth() - self.auth_handler = hs.get_auth_handler() - - class PostBody(RequestBodyModel): - auth: Optional[AuthenticationData] = None - client_secret: ClientSecretStr - sid: StrictStr - - @interactive_auth_handler - async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if not self.hs.config.registration.enable_3pid_changes: - raise SynapseError( - 400, "3PID changes are disabled on this server", Codes.FORBIDDEN - ) - - requester = await 
self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - body = parse_and_validate_json_object_from_request(request, self.PostBody) - - await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body.dict(exclude_unset=True), - "add a third-party identifier to your account", - ) - - validation_session = await self.identity_handler.validate_threepid_session( - body.client_secret, body.sid - ) - if validation_session: - await self.auth_handler.add_threepid( - user_id, - validation_session["medium"], - validation_session["address"], - validation_session["validated_at"], - ) - return 200, {} - - raise SynapseError( - 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED - ) - - -class ThreepidBindRestServlet(RestServlet): - PATTERNS = client_patterns("/account/3pid/bind$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.hs = hs - self.identity_handler = hs.get_identity_handler() - self.auth = hs.get_auth() - - class PostBody(RequestBodyModel): - client_secret: ClientSecretStr - id_access_token: StrictStr - id_server: StrictStr - sid: StrictStr - - async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - body = parse_and_validate_json_object_from_request(request, self.PostBody) - - requester = await self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - - await self.identity_handler.bind_threepid( - body.client_secret, body.sid, user_id, body.id_server, body.id_access_token - ) - - return 200, {} - - -class ThreepidUnbindRestServlet(RestServlet): - PATTERNS = client_patterns("/account/3pid/unbind$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.hs = hs - self.identity_handler = hs.get_identity_handler() - self.auth = hs.get_auth() - self.datastore = self.hs.get_datastores().main - - class PostBody(RequestBodyModel): - address: StrictStr - id_server: Optional[StrictStr] = None - medium: Literal["email", "msisdn"] - - async def 
on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - """Unbind the given 3pid from a specific identity server, or identity servers that are - known to have this 3pid bound - """ - requester = await self.auth.get_user_by_req(request) - body = parse_and_validate_json_object_from_request(request, self.PostBody) - - # Attempt to unbind the threepid from an identity server. If id_server is None, try to - # unbind from all identity servers this threepid has been added to in the past - result = await self.identity_handler.try_unbind_threepid( - requester.user.to_string(), body.medium, body.address, body.id_server - ) - return 200, {"id_server_unbind_result": "success" if result else "no-support"} - - -class ThreepidDeleteRestServlet(RestServlet): - PATTERNS = client_patterns("/account/3pid/delete$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - self.auth_handler = hs.get_auth_handler() - - class PostBody(RequestBodyModel): - address: StrictStr - id_server: Optional[StrictStr] = None - medium: Literal["email", "msisdn"] - - async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if not self.hs.config.registration.enable_3pid_changes: - raise SynapseError( - 400, "3PID changes are disabled on this server", Codes.FORBIDDEN - ) - - body = parse_and_validate_json_object_from_request(request, self.PostBody) - - requester = await self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - - try: - # Attempt to remove any known bindings of this third-party ID - # and user ID from identity servers. - ret = await self.hs.get_identity_handler().try_unbind_threepid( - user_id, body.medium, body.address, body.id_server - ) - except Exception: - # NB. This endpoint should succeed if there is nothing to - # delete, so it should only throw if something is wrong - # that we ought to care about. 
- logger.exception("Failed to remove threepid") - raise SynapseError(500, "Failed to remove threepid") - - if ret: - id_server_unbind_result = "success" - else: - id_server_unbind_result = "no-support" - - # Delete the local association of this user ID and third-party ID. - await self.auth_handler.delete_local_threepid( - user_id, body.medium, body.address - ) - - return 200, {"id_server_unbind_result": id_server_unbind_result} - - def assert_valid_next_link(hs: "HomeServer", next_link: str) -> None: """ Raises a SynapseError if a given next_link value is invalid @@ -901,20 +311,8 @@ class AccountStatusRestServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: if hs.config.worker.worker_app is None: if not hs.config.experimental.msc3861.enabled: - EmailPasswordRequestTokenRestServlet(hs).register(http_server) DeactivateAccountRestServlet(hs).register(http_server) PasswordRestServlet(hs).register(http_server) - EmailThreepidRequestTokenRestServlet(hs).register(http_server) - MsisdnThreepidRequestTokenRestServlet(hs).register(http_server) - AddThreepidEmailSubmitTokenServlet(hs).register(http_server) - AddThreepidMsisdnSubmitTokenServlet(hs).register(http_server) - ThreepidRestServlet(hs).register(http_server) - if hs.config.worker.worker_app is None: - ThreepidBindRestServlet(hs).register(http_server) - ThreepidUnbindRestServlet(hs).register(http_server) - if not hs.config.experimental.msc3861.enabled: - ThreepidAddRestServlet(hs).register(http_server) - ThreepidDeleteRestServlet(hs).register(http_server) WhoamiRestServlet(hs).register(http_server) if hs.config.worker.worker_app is None and hs.config.experimental.msc3720_enabled: diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py
index 0ee24081fa..734c9e992f 100644 --- a/synapse/rest/client/account_data.py +++ b/synapse/rest/client/account_data.py
@@ -108,9 +108,9 @@ class AccountDataServlet(RestServlet): # Push rules are stored in a separate table and must be queried separately. if account_data_type == AccountDataTypes.PUSH_RULES: - account_data: Optional[JsonMapping] = ( - await self._push_rules_handler.push_rules_for_user(requester.user) - ) + account_data: Optional[ + JsonMapping + ] = await self._push_rules_handler.push_rules_for_user(requester.user) else: account_data = await self.store.get_global_account_data_by_type_for_user( user_id, account_data_type diff --git a/synapse/rest/client/account_validity.py b/synapse/rest/client/account_validity.py
index 6222a5cc37..ec7836b647 100644 --- a/synapse/rest/client/account_validity.py +++ b/synapse/rest/client/account_validity.py
@@ -48,9 +48,7 @@ class AccountValidityRenewServlet(RestServlet): self.account_renewed_template = ( hs.config.account_validity.account_validity_account_renewed_template ) - self.account_previously_renewed_template = ( - hs.config.account_validity.account_validity_account_previously_renewed_template - ) + self.account_previously_renewed_template = hs.config.account_validity.account_validity_account_previously_renewed_template self.invalid_token_template = ( hs.config.account_validity.account_validity_invalid_token_template ) diff --git a/synapse/rest/client/appservice_ping.py b/synapse/rest/client/appservice_ping.py
index d6b4e32453..1f9662a95a 100644 --- a/synapse/rest/client/appservice_ping.py +++ b/synapse/rest/client/appservice_ping.py
@@ -2,7 +2,7 @@ # This file is licensed under the Affero General Public License (AGPL) version 3. # # Copyright 2023 Tulir Asokan -# Copyright (C) 2023 New Vector, Ltd +# Copyright (C) 2023, 2025 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as @@ -53,6 +53,7 @@ class AppservicePingRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.as_api = hs.get_application_service_api() + self.scheduler = hs.get_application_service_scheduler() self.auth = hs.get_auth() async def on_POST( @@ -85,6 +86,10 @@ class AppservicePingRestServlet(RestServlet): start = time.monotonic() try: await self.as_api.ping(requester.app_service, txn_id) + + # We got a OK response, so if the AS needs to be recovered then lets recover it now. + # This sets off a task in the background and so is safe to execute and forget. + self.scheduler.txn_ctrl.force_retry(requester.app_service) except RequestTimedOutError as e: raise SynapseError( HTTPStatus.GATEWAY_TIMEOUT, diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py
index 4221f35937..b8dca7c797 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py
@@ -20,14 +20,14 @@ # import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from twisted.web.server import Request from synapse.api.constants import LoginType from synapse.api.errors import LoginError, SynapseError from synapse.api.urls import CLIENT_API_PREFIX -from synapse.http.server import HttpServer, respond_with_html +from synapse.http.server import HttpServer, respond_with_html, respond_with_redirect from synapse.http.servlet import RestServlet, parse_string from synapse.http.site import SynapseRequest @@ -66,6 +66,23 @@ class AuthRestServlet(RestServlet): if not session: raise SynapseError(400, "No session supplied") + if ( + self.hs.config.experimental.msc3861.enabled + and stagetype == "org.matrix.cross_signing_reset" + ): + # If MSC3861 is enabled, we can assume self._auth is an instance of MSC3861DelegatedAuth + # We import lazily here because of the authlib requirement + from synapse.api.auth.msc3861_delegated import MSC3861DelegatedAuth + + auth = cast(MSC3861DelegatedAuth, self.auth) + + url = await auth.account_management_url() + if url is not None: + url = f"{url}?action=org.matrix.cross_signing_reset" + else: + url = await auth.issuer() + respond_with_redirect(request, str.encode(url)) + if stagetype == LoginType.RECAPTCHA: html = self.recaptcha_template.render( session=session, diff --git a/synapse/rest/client/auth_issuer.py b/synapse/rest/client/auth_metadata.py
index 77b9720956..5444a89be6 100644 --- a/synapse/rest/client/auth_issuer.py +++ b/synapse/rest/client/auth_metadata.py
@@ -13,7 +13,7 @@ # limitations under the License. import logging import typing -from typing import Tuple +from typing import Tuple, cast from synapse.api.errors import Codes, SynapseError from synapse.http.server import HttpServer @@ -32,6 +32,8 @@ logger = logging.getLogger(__name__) class AuthIssuerServlet(RestServlet): """ Advertises what OpenID Connect issuer clients should use to authorise users. + This endpoint was defined in a previous iteration of MSC2965, and is still + used by some clients. """ PATTERNS = client_patterns( @@ -43,10 +45,16 @@ class AuthIssuerServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self._config = hs.config + self._auth = hs.get_auth() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if self._config.experimental.msc3861.enabled: - return 200, {"issuer": self._config.experimental.msc3861.issuer} + # If MSC3861 is enabled, we can assume self._auth is an instance of MSC3861DelegatedAuth + # We import lazily here because of the authlib requirement + from synapse.api.auth.msc3861_delegated import MSC3861DelegatedAuth + + auth = cast(MSC3861DelegatedAuth, self._auth) + return 200, {"issuer": await auth.issuer()} else: # Wouldn't expect this to be reached: the servelet shouldn't have been # registered. Still, fail gracefully if we are registered for some reason. @@ -57,7 +65,42 @@ class AuthIssuerServlet(RestServlet): ) +class AuthMetadataServlet(RestServlet): + """ + Advertises the OAuth 2.0 server metadata for the homeserver. 
+ """ + + PATTERNS = client_patterns( + "/org.matrix.msc2965/auth_metadata$", + unstable=True, + releases=(), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._config = hs.config + self._auth = hs.get_auth() + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + if self._config.experimental.msc3861.enabled: + # If MSC3861 is enabled, we can assume self._auth is an instance of MSC3861DelegatedAuth + # We import lazily here because of the authlib requirement + from synapse.api.auth.msc3861_delegated import MSC3861DelegatedAuth + + auth = cast(MSC3861DelegatedAuth, self._auth) + return 200, await auth.auth_metadata() + else: + # Wouldn't expect this to be reached: the servlet shouldn't have been + # registered. Still, fail gracefully if we are registered for some reason. + raise SynapseError( + 404, + "OIDC discovery has not been configured on this homeserver", + Codes.NOT_FOUND, + ) + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: # We use the MSC3861 values as they are used by multiple MSCs if hs.config.experimental.msc3861.enabled: AuthIssuerServlet(hs).register(http_server) + AuthMetadataServlet(hs).register(http_server) diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py
index 63b8a9364a..caac5826a4 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py
@@ -21,7 +21,7 @@ import logging from http import HTTPStatus from typing import TYPE_CHECKING, Tuple -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, MSC3244_CAPABILITIES +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet from synapse.http.site import SynapseRequest @@ -69,7 +69,7 @@ class CapabilitiesRestServlet(RestServlet): "enabled": self.config.registration.enable_set_avatar_url }, "m.3pid_changes": { - "enabled": self.config.registration.enable_3pid_changes + "enabled": False }, "m.get_login_token": { "enabled": self.config.auth.login_via_existing_enabled, @@ -77,11 +77,6 @@ class CapabilitiesRestServlet(RestServlet): } } - if self.config.experimental.msc3244_enabled: - response["capabilities"]["m.room_versions"][ - "org.matrix.msc3244.room_capabilities" - ] = MSC3244_CAPABILITIES - if self.config.experimental.msc3720_enabled: response["capabilities"]["org.matrix.msc3720.account_status"] = { "enabled": True, @@ -92,6 +87,23 @@ class CapabilitiesRestServlet(RestServlet): "enabled": self.config.experimental.msc3664_enabled, } + if self.config.experimental.msc4133_enabled: + response["capabilities"]["uk.tcpip.msc4133.profile_fields"] = { + "enabled": True, + } + + # Ensure this is consistent with the legacy m.set_displayname and + # m.set_avatar_url. + disallowed = [] + if not self.config.registration.enable_set_displayname: + disallowed.append("displayname") + if not self.config.registration.enable_set_avatar_url: + disallowed.append("avatar_url") + if disallowed: + response["capabilities"]["uk.tcpip.msc4133.profile_fields"][ + "disallowed" + ] = disallowed + return HTTPStatus.OK, response diff --git a/synapse/rest/client/delayed_events.py b/synapse/rest/client/delayed_events.py new file mode 100644
index 0000000000..2dd5a60b2b --- /dev/null +++ b/synapse/rest/client/delayed_events.py
@@ -0,0 +1,111 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2024 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# + +# This module contains REST servlets to do with delayed events: /delayed_events/<paths> + +import logging +from enum import Enum +from http import HTTPStatus +from typing import TYPE_CHECKING, Tuple + +from synapse.api.errors import Codes, SynapseError +from synapse.http.server import HttpServer +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest +from synapse.rest.client._base import client_patterns +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class _UpdateDelayedEventAction(Enum): + CANCEL = "cancel" + RESTART = "restart" + SEND = "send" + + +class UpdateDelayedEventServlet(RestServlet): + PATTERNS = client_patterns( + r"/org\.matrix\.msc4140/delayed_events/(?P<delay_id>[^/]+)$", + releases=(), + ) + CATEGORY = "Delayed event management requests" + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.delayed_events_handler = hs.get_delayed_events_handler() + + async def on_POST( + self, request: SynapseRequest, delay_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + + body = parse_json_object_from_request(request) + try: + action = str(body["action"]) + except KeyError: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "'action' is missing", + Codes.MISSING_PARAM, + ) + try: + enum_action = 
_UpdateDelayedEventAction(action) + except ValueError: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "'action' is not one of " + + ", ".join(f"'{m.value}'" for m in _UpdateDelayedEventAction), + Codes.INVALID_PARAM, + ) + + if enum_action == _UpdateDelayedEventAction.CANCEL: + await self.delayed_events_handler.cancel(requester, delay_id) + elif enum_action == _UpdateDelayedEventAction.RESTART: + await self.delayed_events_handler.restart(requester, delay_id) + elif enum_action == _UpdateDelayedEventAction.SEND: + await self.delayed_events_handler.send(requester, delay_id) + return 200, {} + + +class DelayedEventsServlet(RestServlet): + PATTERNS = client_patterns( + r"/org\.matrix\.msc4140/delayed_events$", + releases=(), + ) + CATEGORY = "Delayed event management requests" + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.delayed_events_handler = hs.get_delayed_events_handler() + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + # TODO: Support Pagination stream API ("from" query parameter) + delayed_events = await self.delayed_events_handler.get_all_for_user(requester) + + ret = {"delayed_events": delayed_events} + return 200, ret + + +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + # The following can't currently be instantiated on workers. + if hs.config.worker.worker_app is None: + UpdateDelayedEventServlet(hs).register(http_server) + DelayedEventsServlet(hs).register(http_server) diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 8313d687b7..0b075cc2f2 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py
@@ -24,13 +24,7 @@ import logging from http import HTTPStatus from typing import TYPE_CHECKING, List, Optional, Tuple -from synapse._pydantic_compat import HAS_PYDANTIC_V2 - -if TYPE_CHECKING or HAS_PYDANTIC_V2: - from pydantic.v1 import Extra, StrictStr -else: - from pydantic import Extra, StrictStr - +from synapse._pydantic_compat import Extra, StrictStr from synapse.api import errors from synapse.api.errors import NotFoundError, SynapseError, UnrecognizedRequestError from synapse.handlers.device import DeviceHandler @@ -120,15 +114,19 @@ class DeleteDevicesRestServlet(RestServlet): else: raise e - await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body.dict(exclude_unset=True), - "remove device(s) from your account", - # Users might call this multiple times in a row while cleaning up - # devices, allow a single UI auth session to be re-used. - can_skip_ui_auth=True, - ) + if requester.app_service and requester.app_service.msc4190_device_management: + # MSC4190 can skip UIA for this endpoint + pass + else: + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body.dict(exclude_unset=True), + "remove device(s) from your account", + # Users might call this multiple times in a row while cleaning up + # devices, allow a single UI auth session to be re-used. 
+ can_skip_ui_auth=True, + ) await self.device_handler.delete_devices( requester.user.to_string(), body.devices @@ -145,11 +143,11 @@ class DeviceRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() handler = hs.get_device_handler() - assert isinstance(handler, DeviceHandler) self.device_handler = handler self.auth_handler = hs.get_auth_handler() self._msc3852_enabled = hs.config.experimental.msc3852_enabled self._msc3861_oauth_delegation_enabled = hs.config.experimental.msc3861.enabled + self._is_main_process = hs.config.worker.worker_app is None async def on_GET( self, request: SynapseRequest, device_id: str @@ -181,8 +179,13 @@ class DeviceRestServlet(RestServlet): async def on_DELETE( self, request: SynapseRequest, device_id: str ) -> Tuple[int, JsonDict]: - if self._msc3861_oauth_delegation_enabled: - raise UnrecognizedRequestError(code=404) + # Can only be run on main process, as changes to device lists must + # happen on main. + if not self._is_main_process: + error_message = "DELETE on /devices/ must be routed to main process" + logger.error(error_message) + raise SynapseError(500, error_message) + assert isinstance(self.device_handler, DeviceHandler) requester = await self.auth.get_user_by_req(request) @@ -198,15 +201,24 @@ class DeviceRestServlet(RestServlet): else: raise - await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body.dict(exclude_unset=True), - "remove a device from your account", - # Users might call this multiple times in a row while cleaning up - # devices, allow a single UI auth session to be re-used. 
- can_skip_ui_auth=True, - ) + if requester.app_service and requester.app_service.msc4190_device_management: + # MSC4190 allows appservices to delete devices through this endpoint without UIA + # It's also allowed with MSC3861 enabled + pass + + else: + if self._msc3861_oauth_delegation_enabled: + raise UnrecognizedRequestError(code=404) + + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body.dict(exclude_unset=True), + "remove a device from your account", + # Users might call this multiple times in a row while cleaning up + # devices, allow a single UI auth session to be re-used. + can_skip_ui_auth=True, + ) await self.device_handler.delete_devices( requester.user.to_string(), [device_id] @@ -219,9 +231,27 @@ class DeviceRestServlet(RestServlet): async def on_PUT( self, request: SynapseRequest, device_id: str ) -> Tuple[int, JsonDict]: + # Can only be run on main process, as changes to device lists must + # happen on main. + if not self._is_main_process: + error_message = "PUT on /devices/ must be routed to main process" + logger.error(error_message) + raise SynapseError(500, error_message) + assert isinstance(self.device_handler, DeviceHandler) + requester = await self.auth.get_user_by_req(request, allow_guest=True) body = parse_and_validate_json_object_from_request(request, self.PutBody) + + # MSC4190 allows appservices to create devices through this endpoint + if requester.app_service and requester.app_service.msc4190_device_management: + created = await self.device_handler.upsert_device( + user_id=requester.user.to_string(), + device_id=device_id, + display_name=body.display_name, + ) + return 201 if created else 200, {} + await self.device_handler.update_device( requester.user.to_string(), device_id, body.dict() ) @@ -571,9 +601,9 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ): DeleteDevicesRestServlet(hs).register(http_server) DevicesRestServlet(hs).register(http_server) + 
DeviceRestServlet(hs).register(http_server) if hs.config.worker.worker_app is None: - DeviceRestServlet(hs).register(http_server) if hs.config.experimental.msc2697_enabled: DehydratedDeviceServlet(hs).register(http_server) ClaimDehydratedDeviceServlet(hs).register(http_server) diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py
index 11fdd0f7c6..479f489623 100644 --- a/synapse/rest/client/directory.py +++ b/synapse/rest/client/directory.py
@@ -20,19 +20,11 @@ # import logging -from typing import TYPE_CHECKING, List, Optional, Tuple - -from synapse._pydantic_compat import HAS_PYDANTIC_V2 - -if TYPE_CHECKING or HAS_PYDANTIC_V2: - from pydantic.v1 import StrictStr -else: - from pydantic import StrictStr - -from typing_extensions import Literal +from typing import TYPE_CHECKING, List, Literal, Optional, Tuple from twisted.web.server import Request +from synapse._pydantic_compat import StrictStr from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import ( diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py
index 613890061e..ad23cc76ce 100644 --- a/synapse/rest/client/events.py +++ b/synapse/rest/client/events.py
@@ -20,6 +20,7 @@ # """This module contains REST servlets to do with event streaming, /events.""" + import logging from typing import TYPE_CHECKING, Dict, List, Tuple, Union diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index eddad7d5b8..7025662fdc 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py
@@ -23,10 +23,13 @@ import logging import re from collections import Counter -from http import HTTPStatus -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, cast -from synapse.api.errors import Codes, InvalidAPICallError, SynapseError +from synapse.api.errors import ( + InteractiveAuthIncompleteError, + InvalidAPICallError, + SynapseError, +) from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, @@ -403,17 +406,36 @@ class SigningKeyUploadServlet(RestServlet): # explicitly mark the master key as replaceable. if self.hs.config.experimental.msc3861.enabled: if not master_key_updatable_without_uia: - config = self.hs.config.experimental.msc3861 - if config.account_management_url is not None: - url = f"{config.account_management_url}?action=org.matrix.cross_signing_reset" + # If MSC3861 is enabled, we can assume self.auth is an instance of MSC3861DelegatedAuth + # We import lazily here because of the authlib requirement + from synapse.api.auth.msc3861_delegated import MSC3861DelegatedAuth + + auth = cast(MSC3861DelegatedAuth, self.auth) + + uri = await auth.account_management_url() + if uri is not None: + url = f"{uri}?action=org.matrix.cross_signing_reset" else: - url = config.issuer + url = await auth.issuer() - raise SynapseError( - HTTPStatus.NOT_IMPLEMENTED, - "To reset your end-to-end encryption cross-signing identity, " - f"you first need to approve it at {url} and then try again.", - Codes.UNRECOGNIZED, + # We use a dummy session ID as this isn't really a UIA flow, but we + # reuse the same API shape for better client compatibility. 
+ raise InteractiveAuthIncompleteError( + "dummy", + { + "session": "dummy", + "flows": [ + {"stages": ["org.matrix.cross_signing_reset"]}, + ], + "params": { + "org.matrix.cross_signing_reset": { + "url": url, + }, + }, + "msg": "To reset your end-to-end encryption cross-signing " + f"identity, you first need to approve it at {url} and " + "then try again.", + }, ) else: # Without MSC3861, we require UIA. diff --git a/synapse/rest/client/knock.py b/synapse/rest/client/knock.py
index e31687fc13..d7a17e1b35 100644 --- a/synapse/rest/client/knock.py +++ b/synapse/rest/client/knock.py
@@ -53,7 +53,6 @@ class KnockRoomAliasServlet(RestServlet): super().__init__() self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() - self._support_via = hs.config.experimental.msc4156_enabled async def on_POST( self, @@ -72,15 +71,11 @@ class KnockRoomAliasServlet(RestServlet): # twisted.web.server.Request.args is incorrectly defined as Optional[Any] args: Dict[bytes, List[bytes]] = request.args # type: ignore - remote_room_hosts = parse_strings_from_args( - args, "server_name", required=False - ) - if self._support_via: + # Prefer via over server_name (deprecated with MSC4156) + remote_room_hosts = parse_strings_from_args(args, "via", required=False) + if remote_room_hosts is None: remote_room_hosts = parse_strings_from_args( - args, - "org.matrix.msc4156.via", - default=remote_room_hosts, - required=False, + args, "server_name", required=False ) elif RoomAlias.is_valid(room_identifier): handler = self.room_member_handler diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index ae691bcdba..cc6863cadc 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py
@@ -30,11 +30,10 @@ from typing import ( List, Optional, Tuple, + TypedDict, Union, ) -from typing_extensions import TypedDict - from synapse.api.constants import ApprovalNoticeMedium from synapse.api.errors import ( Codes, @@ -82,7 +81,6 @@ class LoginRestServlet(RestServlet): PATTERNS = client_patterns("/login$", v1=True) CATEGORY = "Registration/login requests" - CAS_TYPE = "m.login.cas" SSO_TYPE = "m.login.sso" TOKEN_TYPE = "m.login.token" JWT_TYPE = "org.matrix.login.jwt" @@ -98,8 +96,6 @@ class LoginRestServlet(RestServlet): self.jwt_enabled = hs.config.jwt.jwt_enabled # SSO configuration. - self.saml2_enabled = hs.config.saml2.saml2_enabled - self.cas_enabled = hs.config.cas.cas_enabled self.oidc_enabled = hs.config.oidc.oidc_enabled self._refresh_tokens_enabled = ( hs.config.registration.refreshable_access_token_lifetime is not None @@ -136,7 +132,7 @@ class LoginRestServlet(RestServlet): cfg=self.hs.config.ratelimiting.rc_login_account, ) - # ensure the CAS/SAML/OIDC handlers are loaded on this worker instance. + # ensure the OIDC handlers are loaded on this worker instance. # The reason for this is to ensure that the auth_provider_ids are registered # with SsoHandler, which in turn ensures that the login/registration prometheus # counters are initialised for the auth_provider_ids. @@ -147,15 +143,10 @@ class LoginRestServlet(RestServlet): if self.jwt_enabled: flows.append({"type": LoginRestServlet.JWT_TYPE}) - if self.cas_enabled: - # we advertise CAS for backwards compat, though MSC1721 renamed it - # to SSO. - flows.append({"type": LoginRestServlet.CAS_TYPE}) - # The login token flow requires m.login.token to be advertised. 
support_login_token_flow = self._get_login_token_enabled - if self.cas_enabled or self.saml2_enabled or self.oidc_enabled: + if self.oidc_enabled: flows.append( { "type": LoginRestServlet.SSO_TYPE, @@ -268,7 +259,7 @@ class LoginRestServlet(RestServlet): approval_notice_medium=ApprovalNoticeMedium.NONE, ) - well_known_data = self._well_known_builder.get_well_known() + well_known_data = await self._well_known_builder.get_well_known() if well_known_data: result["well_known"] = well_known_data return 200, result @@ -325,7 +316,7 @@ class LoginRestServlet(RestServlet): *, request_info: RequestInfo, ) -> LoginResponse: - """Handle non-token/saml/jwt logins + """Handle non-token/jwt logins Args: login_submission: @@ -363,6 +354,7 @@ class LoginRestServlet(RestServlet): login_submission: JsonDict, callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None, create_non_existent_users: bool = False, + default_display_name: Optional[str] = None, ratelimit: bool = True, auth_provider_id: Optional[str] = None, should_issue_refresh_token: bool = False, @@ -410,7 +402,8 @@ class LoginRestServlet(RestServlet): canonical_uid = await self.auth_handler.check_user_exists(user_id) if not canonical_uid: canonical_uid = await self.registration_handler.register_user( - localpart=UserID.from_string(user_id).localpart + localpart=UserID.from_string(user_id).localpart, + default_display_name=default_display_name, ) user_id = canonical_uid @@ -546,11 +539,14 @@ class LoginRestServlet(RestServlet): Returns: The body of the JSON response. 
""" - user_id = self.hs.get_jwt_handler().validate_login(login_submission) + user_id, default_display_name = self.hs.get_jwt_handler().validate_login( + login_submission + ) return await self._complete_login( user_id, login_submission, create_non_existent_users=True, + default_display_name=default_display_name, should_issue_refresh_token=should_issue_refresh_token, request_info=request_info, ) @@ -622,7 +618,7 @@ class RefreshTokenServlet(RestServlet): class SsoRedirectServlet(RestServlet): - PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [ + PATTERNS = list(client_patterns("/login/sso/redirect$", v1=True)) + [ re.compile( "^" + CLIENT_API_PREFIX @@ -679,31 +675,6 @@ class SsoRedirectServlet(RestServlet): finish_request(request) -class CasTicketServlet(RestServlet): - PATTERNS = client_patterns("/login/cas/ticket", v1=True) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self._cas_handler = hs.get_cas_handler() - - async def on_GET(self, request: SynapseRequest) -> None: - client_redirect_url = parse_string(request, "redirectUrl") - ticket = parse_string(request, "ticket", required=True) - - # Maybe get a session ID (if this ticket is from user interactive - # authentication). - session = parse_string(request, "session") - - # Either client_redirect_url or session must be provided. 
- if not client_redirect_url and not session: - message = "Missing string query parameter redirectUrl or session" - raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) - - await self._cas_handler.handle_ticket( - request, ticket, client_redirect_url, session - ) - - def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: if hs.config.experimental.msc3861.enabled: return @@ -715,26 +686,18 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ): RefreshTokenServlet(hs).register(http_server) if ( - hs.config.cas.cas_enabled - or hs.config.saml2.saml2_enabled - or hs.config.oidc.oidc_enabled + hs.config.oidc.oidc_enabled ): SsoRedirectServlet(hs).register(http_server) - if hs.config.cas.cas_enabled: - CasTicketServlet(hs).register(http_server) def _load_sso_handlers(hs: "HomeServer") -> None: """Ensure that the SSO handlers are loaded, if they are enabled by configuration. - This is mostly useful to ensure that the CAS/SAML/OIDC handlers register themselves + This is mostly useful to ensure that the OIDC handler registers itself with the main SsoHandler. It's safe to call this multiple times. """ - if hs.config.cas.cas_enabled: - hs.get_cas_handler() - if hs.config.saml2.saml2_enabled: - hs.get_saml_handler() if hs.config.oidc.oidc_enabled: hs.get_oidc_handler() diff --git a/synapse/rest/client/media.py b/synapse/rest/client/media.py
index c30e3022de..4c044ae900 100644 --- a/synapse/rest/client/media.py +++ b/synapse/rest/client/media.py
@@ -102,10 +102,17 @@ class MediaConfigResource(RestServlet): self.clock = hs.get_clock() self.auth = hs.get_auth() self.limits_dict = {"m.upload.size": config.media.max_upload_size} + self.media_repository_callbacks = hs.get_module_api_callbacks().media_repository async def on_GET(self, request: SynapseRequest) -> None: - await self.auth.get_user_by_req(request) - respond_with_json(request, 200, self.limits_dict, send_cors=True) + requester = await self.auth.get_user_by_req(request) + user_specific_config = ( + await self.media_repository_callbacks.get_media_config_for_user( + requester.user.to_string(), + ) + ) + response = user_specific_config if user_specific_config else self.limits_dict + respond_with_json(request, 200, response, send_cors=True) class ThumbnailResource(RestServlet): @@ -138,7 +145,7 @@ class ThumbnailResource(RestServlet): ) -> None: # Validate the server name, raising if invalid parse_and_validate_server_name(server_name) - await self.auth.get_user_by_req(request) + await self.auth.get_user_by_req(request, allow_guest=True) set_cors_headers(request) set_corp_headers(request) @@ -229,7 +236,7 @@ class DownloadResource(RestServlet): # Validate the server name, raising if invalid parse_and_validate_server_name(server_name) - await self.auth.get_user_by_req(request) + await self.auth.get_user_by_req(request, allow_guest=True) set_cors_headers(request) set_corp_headers(request) diff --git a/synapse/rest/client/presence.py b/synapse/rest/client/presence.py
index 572e92642c..104d54cd89 100644 --- a/synapse/rest/client/presence.py +++ b/synapse/rest/client/presence.py
@@ -19,12 +19,13 @@ # # -""" This module contains REST servlets to do with presence: /presence/<paths> -""" +"""This module contains REST servlets to do with presence: /presence/<paths>""" + import logging from typing import TYPE_CHECKING, Tuple -from synapse.api.errors import AuthError, SynapseError +from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError +from synapse.api.ratelimiting import Ratelimiter from synapse.handlers.presence import format_user_presence_state from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -48,6 +49,14 @@ class PresenceStatusRestServlet(RestServlet): self.presence_handler = hs.get_presence_handler() self.clock = hs.get_clock() self.auth = hs.get_auth() + self.store = hs.get_datastores().main + + # Ratelimiter for presence updates, keyed by requester. + self._presence_per_user_limiter = Ratelimiter( + store=self.store, + clock=self.clock, + cfg=hs.config.ratelimiting.rc_presence_per_user, + ) async def on_GET( self, request: SynapseRequest, user_id: str @@ -82,6 +91,17 @@ class PresenceStatusRestServlet(RestServlet): if requester.user != user: raise AuthError(403, "Can only set your own presence state") + # ignore the presence update if the ratelimit is exceeded + try: + await self._presence_per_user_limiter.ratelimit(requester) + except LimitExceededError as e: + logger.debug("User presence ratelimit exceeded; ignoring it.") + return 429, { + "errcode": Codes.LIMIT_EXCEEDED, + "error": "Too many requests", + "retry_after_ms": e.retry_after_ms, + } + state = {} content = parse_json_object_from_request(request) diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py
index c1a80c5c3d..8326d8017c 100644 --- a/synapse/rest/client/profile.py +++ b/synapse/rest/client/profile.py
@@ -19,12 +19,15 @@ # # -""" This module contains REST servlets to do with profile: /profile/<paths> """ +"""This module contains REST servlets to do with profile: /profile/<paths>""" +import re from http import HTTPStatus from typing import TYPE_CHECKING, Tuple +from synapse.api.constants import ProfileFields from synapse.api.errors import Codes, SynapseError +from synapse.handlers.profile import MAX_CUSTOM_FIELD_LEN from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, @@ -33,7 +36,8 @@ from synapse.http.servlet import ( ) from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, JsonValue, UserID +from synapse.util.stringutils import is_namedspaced_grammar if TYPE_CHECKING: from synapse.server import HomeServer @@ -91,6 +95,11 @@ class ProfileDisplaynameRestServlet(RestServlet): async def on_PUT( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: + if not UserID.is_valid(user_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM + ) + requester = await self.auth.get_user_by_req(request, allow_guest=True) user = UserID.from_string(user_id) is_admin = await self.auth.is_server_admin(requester) @@ -101,9 +110,7 @@ class ProfileDisplaynameRestServlet(RestServlet): new_name = content["displayname"] except Exception: raise SynapseError( - code=400, - msg="Unable to parse name", - errcode=Codes.BAD_JSON, + 400, "Missing key 'displayname'", errcode=Codes.MISSING_PARAM ) propagate = _read_propagate(self.hs, request) @@ -166,6 +173,11 @@ class ProfileAvatarURLRestServlet(RestServlet): async def on_PUT( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: + if not UserID.is_valid(user_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM + ) + requester = await self.auth.get_user_by_req(request) user = 
UserID.from_string(user_id) is_admin = await self.auth.is_server_admin(requester) @@ -227,19 +239,185 @@ class ProfileRestServlet(RestServlet): user = UserID.from_string(user_id) await self.profile_handler.check_profile_query_allowed(user, requester_user) - displayname = await self.profile_handler.get_displayname(user) - avatar_url = await self.profile_handler.get_avatar_url(user) - - ret = {} - if displayname is not None: - ret["displayname"] = displayname - if avatar_url is not None: - ret["avatar_url"] = avatar_url + ret = await self.profile_handler.get_profile(user_id) return 200, ret +class UnstableProfileFieldRestServlet(RestServlet): + PATTERNS = [ + re.compile( + r"^/_matrix/client/unstable/uk\.tcpip\.msc4133/profile/(?P<user_id>[^/]*)/(?P<field_name>[^/]*)" + ) + ] + CATEGORY = "Event sending requests" + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.hs = hs + self.profile_handler = hs.get_profile_handler() + self.auth = hs.get_auth() + + async def on_GET( + self, request: SynapseRequest, user_id: str, field_name: str + ) -> Tuple[int, JsonDict]: + requester_user = None + + if self.hs.config.server.require_auth_for_profile_requests: + requester = await self.auth.get_user_by_req(request) + requester_user = requester.user + + if not UserID.is_valid(user_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM + ) + + if not field_name: + raise SynapseError(400, "Field name too short", errcode=Codes.INVALID_PARAM) + + if len(field_name.encode("utf-8")) > MAX_CUSTOM_FIELD_LEN: + raise SynapseError(400, "Field name too long", errcode=Codes.KEY_TOO_LARGE) + if not is_namedspaced_grammar(field_name): + raise SynapseError( + 400, + "Field name does not follow Common Namespaced Identifier Grammar", + errcode=Codes.INVALID_PARAM, + ) + + user = UserID.from_string(user_id) + await self.profile_handler.check_profile_query_allowed(user, requester_user) + + if field_name == ProfileFields.DISPLAYNAME: + field_value: 
JsonValue = await self.profile_handler.get_displayname(user) + elif field_name == ProfileFields.AVATAR_URL: + field_value = await self.profile_handler.get_avatar_url(user) + else: + field_value = await self.profile_handler.get_profile_field(user, field_name) + + return 200, {field_name: field_value} + + async def on_PUT( + self, request: SynapseRequest, user_id: str, field_name: str + ) -> Tuple[int, JsonDict]: + if not UserID.is_valid(user_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM + ) + + requester = await self.auth.get_user_by_req(request) + user = UserID.from_string(user_id) + is_admin = await self.auth.is_server_admin(requester) + + if not field_name: + raise SynapseError(400, "Field name too short", errcode=Codes.INVALID_PARAM) + + if len(field_name.encode("utf-8")) > MAX_CUSTOM_FIELD_LEN: + raise SynapseError(400, "Field name too long", errcode=Codes.KEY_TOO_LARGE) + if not is_namedspaced_grammar(field_name): + raise SynapseError( + 400, + "Field name does not follow Common Namespaced Identifier Grammar", + errcode=Codes.INVALID_PARAM, + ) + + content = parse_json_object_from_request(request) + try: + new_value = content[field_name] + except KeyError: + raise SynapseError( + 400, f"Missing key '{field_name}'", errcode=Codes.MISSING_PARAM + ) + + propagate = _read_propagate(self.hs, request) + + requester_suspended = ( + await self.hs.get_datastores().main.get_user_suspended_status( + requester.user.to_string() + ) + ) + + if requester_suspended: + raise SynapseError( + 403, + "Updating profile while account is suspended is not allowed.", + Codes.USER_ACCOUNT_SUSPENDED, + ) + + if field_name == ProfileFields.DISPLAYNAME: + await self.profile_handler.set_displayname( + user, requester, new_value, is_admin, propagate=propagate + ) + elif field_name == ProfileFields.AVATAR_URL: + await self.profile_handler.set_avatar_url( + user, requester, new_value, is_admin, propagate=propagate + ) + else: + await 
self.profile_handler.set_profile_field( + user, requester, field_name, new_value, is_admin + ) + + return 200, {} + + async def on_DELETE( + self, request: SynapseRequest, user_id: str, field_name: str + ) -> Tuple[int, JsonDict]: + if not UserID.is_valid(user_id): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM + ) + + requester = await self.auth.get_user_by_req(request) + user = UserID.from_string(user_id) + is_admin = await self.auth.is_server_admin(requester) + + if not field_name: + raise SynapseError(400, "Field name too short", errcode=Codes.INVALID_PARAM) + + if len(field_name.encode("utf-8")) > MAX_CUSTOM_FIELD_LEN: + raise SynapseError(400, "Field name too long", errcode=Codes.KEY_TOO_LARGE) + if not is_namedspaced_grammar(field_name): + raise SynapseError( + 400, + "Field name does not follow Common Namespaced Identifier Grammar", + errcode=Codes.INVALID_PARAM, + ) + + propagate = _read_propagate(self.hs, request) + + requester_suspended = ( + await self.hs.get_datastores().main.get_user_suspended_status( + requester.user.to_string() + ) + ) + + if requester_suspended: + raise SynapseError( + 403, + "Updating profile while account is suspended is not allowed.", + Codes.USER_ACCOUNT_SUSPENDED, + ) + + if field_name == ProfileFields.DISPLAYNAME: + await self.profile_handler.set_displayname( + user, requester, "", is_admin, propagate=propagate + ) + elif field_name == ProfileFields.AVATAR_URL: + await self.profile_handler.set_avatar_url( + user, requester, "", is_admin, propagate=propagate + ) + else: + await self.profile_handler.delete_profile_field( + user, requester, field_name, is_admin + ) + + return 200, {} + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + # The specific displayname / avatar URL / custom field endpoints *must* appear + # before their corresponding generic profile endpoint. 
ProfileDisplaynameRestServlet(hs).register(http_server) ProfileAvatarURLRestServlet(hs).register(http_server) ProfileRestServlet(hs).register(http_server) + if hs.config.experimental.msc4133_enabled: + UnstableProfileFieldRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py
index a455f95a26..2463b3b38c 100644 --- a/synapse/rest/client/pusher.py +++ b/synapse/rest/client/pusher.py
@@ -34,7 +34,6 @@ from synapse.http.site import SynapseRequest from synapse.push import PusherConfigException from synapse.rest.admin.experimental_features import ExperimentalFeature from synapse.rest.client._base import client_patterns -from synapse.rest.synapse.client.unsubscribe import UnsubscribeResource from synapse.types import JsonDict if TYPE_CHECKING: @@ -161,21 +160,6 @@ class PushersSetRestServlet(RestServlet): return 200, {} -class LegacyPushersRemoveRestServlet(UnsubscribeResource, RestServlet): - """ - A servlet to handle legacy "email unsubscribe" links, forwarding requests to the ``UnsubscribeResource`` - - This should be kept for some time, so unsubscribe links in past emails stay valid. - """ - - PATTERNS = client_patterns("/pushers/remove$", releases=[], v1=False, unstable=True) - - async def on_GET(self, request: SynapseRequest) -> None: - # Forward the request to the UnsubscribeResource - await self._async_render(request) - - def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: PushersRestServlet(hs).register(http_server) PushersSetRestServlet(hs).register(http_server) - LegacyPushersRemoveRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index 89203dc45a..4bf93f485c 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py
@@ -39,9 +39,7 @@ logger = logging.getLogger(__name__) class ReceiptRestServlet(RestServlet): PATTERNS = client_patterns( - "/rooms/(?P<room_id>[^/]*)" - "/receipt/(?P<receipt_type>[^/]*)" - "/(?P<event_id>[^/]*)$" + "/rooms/(?P<room_id>[^/]*)/receipt/(?P<receipt_type>[^/]*)/(?P<event_id>[^/]*)$" ) CATEGORY = "Receipts requests" diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 5dddbc69be..9d18d8ba25 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py
@@ -38,14 +38,12 @@ from synapse.api.errors import ( InteractiveAuthIncompleteError, NotApprovedError, SynapseError, - ThreepidValidationError, UnrecognizedRequestError, ) from synapse.api.ratelimiting import Ratelimiter from synapse.config import ConfigError from synapse.config.homeserver import HomeServerConfig from synapse.config.ratelimiting import FederationRatelimitSettings -from synapse.config.server import is_threepid_reserved from synapse.handlers.auth import AuthHandler from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http.server import HttpServer, finish_request, respond_with_html @@ -56,17 +54,9 @@ from synapse.http.servlet import ( parse_string, ) from synapse.http.site import SynapseRequest -from synapse.metrics import threepid_send_requests -from synapse.push.mailer import Mailer from synapse.types import JsonDict -from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.stringutils import assert_valid_client_secret, random_string -from synapse.util.threepids import ( - canonicalise_email, - check_3pid_allowed, - validate_email, -) from ._base import client_patterns, interactive_auth_handler @@ -76,247 +66,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class EmailRegisterRequestTokenRestServlet(RestServlet): - PATTERNS = client_patterns("/register/email/requestToken$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.hs = hs - self.identity_handler = hs.get_identity_handler() - self.config = hs.config - - if self.hs.config.email.can_verify_email: - self.registration_mailer = Mailer( - hs=self.hs, - app_name=self.config.email.email_app_name, - template_html=self.config.email.email_registration_template_html, - template_text=self.config.email.email_registration_template_text, - ) - self.already_in_use_mailer = Mailer( - hs=self.hs, - app_name=self.config.email.email_app_name, - 
template_html=self.config.email.email_already_in_use_template_html, - template_text=self.config.email.email_already_in_use_template_text, - ) - - async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if not self.hs.config.email.can_verify_email: - logger.warning( - "Email registration has been disabled due to lack of email config" - ) - raise SynapseError( - 400, "Email-based registration has been disabled on this server" - ) - body = parse_json_object_from_request(request) - - assert_params_in_dict(body, ["client_secret", "email", "send_attempt"]) - - # Extract params from body - client_secret = body["client_secret"] - assert_valid_client_secret(client_secret) - - # For emails, canonicalise the address. - # We store all email addresses canonicalised in the DB. - # (See on_POST in EmailThreepidRequestTokenRestServlet - # in synapse/rest/client/account.py) - try: - email = validate_email(body["email"]) - except ValueError as e: - raise SynapseError(400, str(e)) - send_attempt = body["send_attempt"] - next_link = body.get("next_link") # Optional param - - if not await check_3pid_allowed(self.hs, "email", email, registration=True): - raise SynapseError( - 403, - "Your email domain is not authorized to register on this server", - Codes.THREEPID_DENIED, - ) - - await self.identity_handler.ratelimit_request_token_requests( - request, "email", email - ) - - existing_user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( - "email", email - ) - - if existing_user_id is not None: - if self.hs.config.server.request_token_inhibit_3pid_errors: - # Make the client think the operation succeeded. See the rationale in the - # comments for request_token_inhibit_3pid_errors. - # Still send an email to warn the user that an account already exists. - # Also wait for some random amount of time between 100ms and 1s to make it - # look like we did something. 
- await self.already_in_use_mailer.send_already_in_use_mail(email) - await self.hs.get_clock().sleep(random.randint(1, 10) / 10) - return 200, {"sid": random_string(16)} - - raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) - - # Send registration emails from Synapse - sid = await self.identity_handler.send_threepid_validation( - email, - client_secret, - send_attempt, - self.registration_mailer.send_registration_mail, - next_link, - ) - - threepid_send_requests.labels(type="email", reason="register").observe( - send_attempt - ) - - # Wrap the session id in a JSON object - return 200, {"sid": sid} - - -class MsisdnRegisterRequestTokenRestServlet(RestServlet): - PATTERNS = client_patterns("/register/msisdn/requestToken$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.hs = hs - self.identity_handler = hs.get_identity_handler() - - async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - body = parse_json_object_from_request(request) - - assert_params_in_dict( - body, ["client_secret", "country", "phone_number", "send_attempt"] - ) - client_secret = body["client_secret"] - assert_valid_client_secret(client_secret) - country = body["country"] - phone_number = body["phone_number"] - send_attempt = body["send_attempt"] - next_link = body.get("next_link") # Optional param - - msisdn = phone_number_to_msisdn(country, phone_number) - - if not await check_3pid_allowed(self.hs, "msisdn", msisdn, registration=True): - raise SynapseError( - 403, - "Phone numbers are not authorized to register on this server", - Codes.THREEPID_DENIED, - ) - - await self.identity_handler.ratelimit_request_token_requests( - request, "msisdn", msisdn - ) - - existing_user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( - "msisdn", msisdn - ) - - if existing_user_id is not None: - if self.hs.config.server.request_token_inhibit_3pid_errors: - # Make the client think the operation succeeded. 
See the rationale in the - # comments for request_token_inhibit_3pid_errors. - # Also wait for some random amount of time between 100ms and 1s to make it - # look like we did something. - await self.hs.get_clock().sleep(random.randint(1, 10) / 10) - return 200, {"sid": random_string(16)} - - raise SynapseError( - 400, "Phone number is already in use", Codes.THREEPID_IN_USE - ) - - if not self.hs.config.registration.account_threepid_delegate_msisdn: - logger.warning( - "No upstream msisdn account_threepid_delegate configured on the server to " - "handle this request" - ) - raise SynapseError( - 400, "Registration by phone number is not supported on this homeserver" - ) - - ret = await self.identity_handler.requestMsisdnToken( - self.hs.config.registration.account_threepid_delegate_msisdn, - country, - phone_number, - client_secret, - send_attempt, - next_link, - ) - - threepid_send_requests.labels(type="msisdn", reason="register").observe( - send_attempt - ) - - return 200, ret - - -class RegistrationSubmitTokenServlet(RestServlet): - """Handles registration 3PID validation token submission""" - - PATTERNS = client_patterns( - "/registration/(?P<medium>[^/]*)/submit_token$", releases=(), unstable=True - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - self.config = hs.config - self.clock = hs.get_clock() - self.store = hs.get_datastores().main - - if self.config.email.can_verify_email: - self._failure_email_template = ( - self.config.email.email_registration_template_failure_html - ) - - async def on_GET(self, request: Request, medium: str) -> None: - if medium != "email": - raise SynapseError( - 400, "This medium is currently not supported for registration" - ) - if not self.config.email.can_verify_email: - logger.warning( - "User registration via email has been disabled due to lack of email config" - ) - raise SynapseError( - 400, "Email-based registration is disabled on this server" - ) - - sid = 
parse_string(request, "sid", required=True) - client_secret = parse_string(request, "client_secret", required=True) - assert_valid_client_secret(client_secret) - token = parse_string(request, "token", required=True) - - # Attempt to validate a 3PID session - try: - # Mark the session as valid - next_link = await self.store.validate_threepid_session( - sid, client_secret, token, self.clock.time_msec() - ) - - # Perform a 302 redirect if next_link is set - if next_link: - if next_link.startswith("file:///"): - logger.warning( - "Not redirecting to next_link as it is a local file: address" - ) - else: - request.setResponseCode(302) - request.setHeader("Location", next_link) - finish_request(request) - return None - - # Otherwise show the success template - html = self.config.email.email_registration_template_success_html_content - status_code = 200 - except ThreepidValidationError as e: - status_code = e.code - - # Show a failure page with a reason - template_vars = {"failure_reason": e.msg} - html = self._failure_email_template.render(**template_vars) - - respond_with_html(request, status_code, html) - - class UsernameAvailabilityRestServlet(RestServlet): PATTERNS = client_patterns("/register/available") @@ -420,7 +169,6 @@ class RegisterRestServlet(RestServlet): self.store = hs.get_datastores().main self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() - self.identity_handler = hs.get_identity_handler() self.room_member_handler = hs.get_room_member_handler() self.macaroon_gen = hs.get_macaroon_generator() self.ratelimiter = hs.get_registration_ratelimiter() @@ -605,27 +353,6 @@ class RegisterRestServlet(RestServlet): ) raise - # Check that we're not trying to register a denied 3pid. - # - # the user-facing checks will probably already have happened in - # /register/email/requestToken when we requested a 3pid, but that's not - # guaranteed. 
- if auth_result: - for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: - if login_type in auth_result: - medium = auth_result[login_type]["medium"] - address = auth_result[login_type]["address"] - - if not await check_3pid_allowed( - self.hs, medium, address, registration=True - ): - raise SynapseError( - 403, - "Third party identifiers (email/phone numbers)" - + " are not authorized on this server", - Codes.THREEPID_DENIED, - ) - if registered_user_id is not None: logger.info( "Already registered user ID %r for this session", registered_user_id @@ -640,12 +367,10 @@ class RegisterRestServlet(RestServlet): if not password_hash: raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) - desired_username = ( - await ( - self.password_auth_provider.get_username_for_registration( - auth_result, - params, - ) + desired_username = await ( + self.password_auth_provider.get_username_for_registration( + auth_result, + params, ) ) @@ -657,50 +382,13 @@ class RegisterRestServlet(RestServlet): if desired_username is not None: desired_username = desired_username.lower() - threepid = None - if auth_result: - threepid = auth_result.get(LoginType.EMAIL_IDENTITY) - - # Also check that we're not trying to register a 3pid that's already - # been registered. - # - # This has probably happened in /register/email/requestToken as well, - # but if a user hits this endpoint twice then clicks on each link from - # the two activation emails, they would register the same 3pid twice. - for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: - if login_type in auth_result: - medium = auth_result[login_type]["medium"] - address = auth_result[login_type]["address"] - # For emails, canonicalise the address. - # We store all email addresses canonicalised in the DB. 
- # (See on_POST in EmailThreepidRequestTokenRestServlet - # in synapse/rest/client/account.py) - if medium == "email": - try: - address = canonicalise_email(address) - except ValueError as e: - raise SynapseError(400, str(e)) - - existing_user_id = await self.store.get_user_id_by_threepid( - medium, address - ) - - if existing_user_id is not None: - raise SynapseError( - 400, - "%s is already in use" % medium, - Codes.THREEPID_IN_USE, - ) - entries = await self.store.get_user_agents_ips_to_ui_auth_session( session_id ) - display_name = ( - await ( - self.password_auth_provider.get_displayname_for_registration( - auth_result, params - ) + display_name = await ( + self.password_auth_provider.get_displayname_for_registration( + auth_result, params ) ) @@ -708,18 +396,10 @@ class RegisterRestServlet(RestServlet): localpart=desired_username, password_hash=password_hash, guest_access_token=guest_access_token, - threepid=threepid, default_display_name=display_name, address=client_addr, user_agent_ips=entries, ) - # Necessary due to auth checks prior to the threepid being - # written to the db - if threepid: - if is_threepid_reserved( - self.hs.config.server.mau_limits_reserved_threepids, threepid - ): - await self.store.upsert_monthly_active_user(registered_user_id) # Remember that the user account has been registered (and the user # ID it was registered with, since it might not have been specified). 
@@ -775,9 +455,12 @@ class RegisterRestServlet(RestServlet): body: JsonDict, should_issue_refresh_token: bool = False, ) -> JsonDict: - user_id = await self.registration_handler.appservice_register( + user_id, appservice = await self.registration_handler.appservice_register( username, as_token ) + if appservice.msc4190_device_management: + body["inhibit_login"] = True + return await self._create_registration_details( user_id, body, @@ -909,6 +592,14 @@ class RegisterAppServiceOnlyRestServlet(RestServlet): await self.ratelimiter.ratelimit(None, client_addr, update=False) + # Allow only ASes to use this API. + if body.get("type") != APP_SERVICE_REGISTRATION_TYPE: + raise SynapseError( + 403, + "Registration has been disabled. Only m.login.application_service registrations are allowed.", + errcode=Codes.FORBIDDEN, + ) + kind = parse_string(request, "kind", default="user") if kind == "guest": @@ -924,10 +615,6 @@ class RegisterAppServiceOnlyRestServlet(RestServlet): if not isinstance(desired_username, str) or len(desired_username) > 512: raise SynapseError(400, "Invalid username") - # Allow only ASes to use this API. 
- if body.get("type") != APP_SERVICE_REGISTRATION_TYPE: - raise SynapseError(403, "Non-application service registration type") - if not self.auth.has_access_token(request): raise SynapseError( 400, @@ -941,7 +628,7 @@ class RegisterAppServiceOnlyRestServlet(RestServlet): as_token = self.auth.get_access_token_from_request(request) - user_id = await self.registration_handler.appservice_register( + user_id, _ = await self.registration_handler.appservice_register( desired_username, as_token ) return 200, {"user_id": user_id} @@ -958,60 +645,11 @@ def _calculate_registration_flows( Returns: a list of supported flows """ - # FIXME: need a better error than "no auth flow found" for scenarios - # where we required 3PID for registration but the user didn't give one - require_email = "email" in config.registration.registrations_require_3pid - require_msisdn = "msisdn" in config.registration.registrations_require_3pid - - show_msisdn = True - show_email = True - - if config.registration.disable_msisdn_registration: - show_msisdn = False - require_msisdn = False - enabled_auth_types = auth_handler.get_enabled_auth_types() - if LoginType.EMAIL_IDENTITY not in enabled_auth_types: - show_email = False - if require_email: - raise ConfigError( - "Configuration requires email address at registration, but email " - "validation is not configured" - ) - - if LoginType.MSISDN not in enabled_auth_types: - show_msisdn = False - if require_msisdn: - raise ConfigError( - "Configuration requires msisdn at registration, but msisdn " - "validation is not configured" - ) - flows = [] - # only support 3PIDless registration if no 3PIDs are required - if not require_email and not require_msisdn: - # Add a dummy step here, otherwise if a client completes - # recaptcha first we'll assume they were going for this flow - # and complete the request, when they could have been trying to - # complete one of the flows with email/msisdn auth. 
- flows.append([LoginType.DUMMY]) - - # only support the email-only flow if we don't require MSISDN 3PIDs - if show_email and not require_msisdn: - flows.append([LoginType.EMAIL_IDENTITY]) - - # only support the MSISDN-only flow if we don't require email 3PIDs - if show_msisdn and not require_email: - flows.append([LoginType.MSISDN]) - - if show_email and show_msisdn: - # always let users provide both MSISDN & email - flows.append([LoginType.MSISDN, LoginType.EMAIL_IDENTITY]) - - # Add a flow that doesn't require any 3pids, if the config requests it. - if config.registration.enable_registration_token_3pid_bypass: - flows.append([LoginType.REGISTRATION_TOKEN]) + # We don't support 3PIDs + flows.append([LoginType.DUMMY]) # Prepend m.login.terms to all flows if we're requiring consent if config.consent.user_consent_at_registration: @@ -1037,10 +675,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RegisterAppServiceOnlyRestServlet(hs).register(http_server) return - if hs.config.worker.worker_app is None: - EmailRegisterRequestTokenRestServlet(hs).register(http_server) - MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) - RegistrationSubmitTokenServlet(hs).register(http_server) UsernameAvailabilityRestServlet(hs).register(http_server) RegistrationTokenValidityRestServlet(hs).register(http_server) RegisterRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/rendezvous.py b/synapse/rest/client/rendezvous.py
index 27bf53314a..a1808847f0 100644 --- a/synapse/rest/client/rendezvous.py +++ b/synapse/rest/client/rendezvous.py
@@ -34,51 +34,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -# n.b [MSC3886](https://github.com/matrix-org/matrix-spec-proposals/pull/3886) has now been closed. -# However, we want to keep this implementation around for some time. -# TODO: define an end-of-life date for this implementation. -class MSC3886RendezvousServlet(RestServlet): - """ - This is a placeholder implementation of [MSC3886](https://github.com/matrix-org/matrix-spec-proposals/pull/3886) - simple client rendezvous capability that is used by the "Sign in with QR" functionality. - - This implementation only serves as a 307 redirect to a configured server rather than being a full implementation. - - A module that implements the full functionality is available at: https://pypi.org/project/matrix-http-rendezvous-synapse/. - - Request: - - POST /rendezvous HTTP/1.1 - Content-Type: ... - - ... - - Response: - - HTTP/1.1 307 - Location: <configured endpoint> - """ - - PATTERNS = client_patterns( - "/org.matrix.msc3886/rendezvous$", releases=[], v1=False, unstable=True - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - redirection_target: Optional[str] = hs.config.experimental.msc3886_endpoint - assert ( - redirection_target is not None - ), "Servlet is only registered if there is a redirection target" - self.endpoint = redirection_target.encode("utf-8") - - async def on_POST(self, request: SynapseRequest) -> None: - respond_with_redirect( - request, self.endpoint, statusCode=TEMPORARY_REDIRECT, cors=True - ) - - # PUT, GET and DELETE are not implemented as they should be fulfilled by the redirect target. 
- - class MSC4108DelegationRendezvousServlet(RestServlet): PATTERNS = client_patterns( "/org.matrix.msc4108/rendezvous$", releases=[], v1=False, unstable=True @@ -89,9 +44,9 @@ class MSC4108DelegationRendezvousServlet(RestServlet): redirection_target: Optional[str] = ( hs.config.experimental.msc4108_delegation_endpoint ) - assert ( - redirection_target is not None - ), "Servlet is only registered if there is a delegation target" + assert redirection_target is not None, ( + "Servlet is only registered if there is a delegation target" + ) self.endpoint = redirection_target.encode("utf-8") async def on_POST(self, request: SynapseRequest) -> None: @@ -114,9 +69,6 @@ class MSC4108RendezvousServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - if hs.config.experimental.msc3886_endpoint is not None: - MSC3886RendezvousServlet(hs).register(http_server) - if hs.config.experimental.msc4108_enabled: MSC4108RendezvousServlet(hs).register(http_server) diff --git a/synapse/rest/client/reporting.py b/synapse/rest/client/reporting.py
index 4eee53e5a8..c5037be8b7 100644 --- a/synapse/rest/client/reporting.py +++ b/synapse/rest/client/reporting.py
@@ -23,7 +23,7 @@ import logging from http import HTTPStatus from typing import TYPE_CHECKING, Tuple -from synapse._pydantic_compat import HAS_PYDANTIC_V2 +from synapse._pydantic_compat import StrictStr from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import ( @@ -40,10 +40,6 @@ from ._base import client_patterns if TYPE_CHECKING: from synapse.server import HomeServer -if TYPE_CHECKING or HAS_PYDANTIC_V2: - from pydantic.v1 import StrictStr -else: - from pydantic import StrictStr logger = logging.getLogger(__name__) @@ -109,18 +105,17 @@ class ReportEventRestServlet(RestServlet): class ReportRoomRestServlet(RestServlet): """This endpoint lets clients report a room for abuse. - Whilst MSC4151 is not yet merged, this unstable endpoint is enabled on matrix.org - for content moderation purposes, and therefore backwards compatibility should be - carefully considered when changing anything on this endpoint. - - More details on the MSC: https://github.com/matrix-org/matrix-spec-proposals/pull/4151 + Introduced by MSC4151: https://github.com/matrix-org/matrix-spec-proposals/pull/4151 """ - PATTERNS = client_patterns( - "/org.matrix.msc4151/rooms/(?P<room_id>[^/]*)/report$", - releases=[], - v1=False, - unstable=True, + # Cast the Iterable to a list so that we can `append` below. + PATTERNS = list( + client_patterns( + "/rooms/(?P<room_id>[^/]*)/report$", + releases=("v3",), + unstable=False, + v1=False, + ) ) def __init__(self, hs: "HomeServer"): @@ -157,6 +152,4 @@ class ReportRoomRestServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReportEventRestServlet(hs).register(http_server) - - if hs.config.experimental.msc4151_enabled: - ReportRoomRestServlet(hs).register(http_server) + ReportRoomRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 903c74f6d8..38230de0de 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py
@@ -2,7 +2,7 @@ # This file is licensed under the Affero General Public License (AGPL) version 3. # # Copyright 2014-2016 OpenMarket Ltd -# Copyright (C) 2023 New Vector, Ltd +# Copyright (C) 2023-2024 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as @@ -19,7 +19,8 @@ # # -""" This module contains REST servlets to do with rooms: /rooms/<paths> """ +"""This module contains REST servlets to do with rooms: /rooms/<paths>""" + import logging import re from enum import Enum @@ -67,7 +68,8 @@ from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, Requester, StreamToken, ThirdPartyInstanceID, UserID from synapse.types.state import StateFilter from synapse.util.cancellation import cancellable -from synapse.util.stringutils import parse_and_validate_server_name, random_string +from synapse.util.events import generate_fake_event_id +from synapse.util.stringutils import parse_and_validate_server_name if TYPE_CHECKING: from synapse.server import HomeServer @@ -193,7 +195,10 @@ class RoomStateEventRestServlet(RestServlet): self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() self.message_handler = hs.get_message_handler() + self.delayed_events_handler = hs.get_delayed_events_handler() self.auth = hs.get_auth() + self._max_event_delay_ms = hs.config.server.max_event_delay_ms + self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker def register(self, http_server: HttpServer) -> None: # /rooms/$roomid/state/$eventtype @@ -285,10 +290,45 @@ class RoomStateEventRestServlet(RestServlet): content = parse_json_object_from_request(request) + is_requester_admin = await self.auth.is_server_admin(requester) + if not is_requester_admin: + spam_check = ( + await self._spam_checker_module_callbacks.user_may_send_state_event( + user_id=requester.user.to_string(), + 
room_id=room_id, + event_type=event_type, + state_key=state_key, + content=content, + ) + ) + if spam_check != self._spam_checker_module_callbacks.NOT_SPAM: + raise SynapseError( + 403, + "You are not permitted to send the state event", + errcode=spam_check[0], + additional_fields=spam_check[1], + ) + origin_server_ts = None if requester.app_service: origin_server_ts = parse_integer(request, "ts") + delay = _parse_request_delay(request, self._max_event_delay_ms) + if delay is not None: + delay_id = await self.delayed_events_handler.add( + requester, + room_id=room_id, + event_type=event_type, + state_key=state_key, + origin_server_ts=origin_server_ts, + content=content, + delay=delay, + ) + + set_tag("delay_id", delay_id) + ret = {"delay_id": delay_id} + return 200, ret + try: if event_type == EventTypes.Member: membership = content.get("membership", None) @@ -325,7 +365,7 @@ class RoomStateEventRestServlet(RestServlet): ) event_id = event.event_id except ShadowBanError: - event_id = "$" + random_string(43) + event_id = generate_fake_event_id() set_tag("event_id", event_id) ret = {"event_id": event_id} @@ -339,7 +379,9 @@ class RoomSendEventRestServlet(TransactionRestServlet): def __init__(self, hs: "HomeServer"): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() + self.delayed_events_handler = hs.get_delayed_events_handler() self.auth = hs.get_auth() + self._max_event_delay_ms = hs.config.server.max_event_delay_ms def register(self, http_server: HttpServer) -> None: # /rooms/$roomid/send/$event_type[/$txn_id] @@ -356,6 +398,26 @@ class RoomSendEventRestServlet(TransactionRestServlet): ) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) + origin_server_ts = None + if requester.app_service: + origin_server_ts = parse_integer(request, "ts") + + delay = _parse_request_delay(request, self._max_event_delay_ms) + if delay is not None: + delay_id = await self.delayed_events_handler.add( + requester, + 
room_id=room_id, + event_type=event_type, + state_key=None, + origin_server_ts=origin_server_ts, + content=content, + delay=delay, + ) + + set_tag("delay_id", delay_id) + ret = {"delay_id": delay_id} + return 200, ret + event_dict: JsonDict = { "type": event_type, "content": content, @@ -363,10 +425,8 @@ class RoomSendEventRestServlet(TransactionRestServlet): "sender": requester.user.to_string(), } - if requester.app_service: - origin_server_ts = parse_integer(request, "ts") - if origin_server_ts is not None: - event_dict["origin_server_ts"] = origin_server_ts + if origin_server_ts is not None: + event_dict["origin_server_ts"] = origin_server_ts try: ( @@ -377,7 +437,7 @@ class RoomSendEventRestServlet(TransactionRestServlet): ) event_id = event.event_id except ShadowBanError: - event_id = "$" + random_string(43) + event_id = generate_fake_event_id() set_tag("event_id", event_id) return 200, {"event_id": event_id} @@ -409,6 +469,49 @@ class RoomSendEventRestServlet(TransactionRestServlet): ) +def _parse_request_delay( + request: SynapseRequest, + max_delay: Optional[int], +) -> Optional[int]: + """Parses from the request string the delay parameter for + delayed event requests, and checks it for correctness. + + Args: + request: the twisted HTTP request. + max_delay: the maximum allowed value of the delay parameter, + or None if no delay parameter is allowed. + Returns: + The value of the requested delay, or None if it was absent. + + Raises: + SynapseError: if the delay parameter is present and forbidden, + or if it exceeds the maximum allowed value. 
+ """ + delay = parse_integer(request, "org.matrix.msc4140.delay") + if delay is None: + return None + if max_delay is None: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Delayed events are not supported on this server", + Codes.UNKNOWN, + { + "org.matrix.msc4140.errcode": "M_MAX_DELAY_UNSUPPORTED", + }, + ) + if delay > max_delay: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "The requested delay exceeds the allowed maximum.", + Codes.UNKNOWN, + { + "org.matrix.msc4140.errcode": "M_MAX_DELAY_EXCEEDED", + "org.matrix.msc4140.max_delay": max_delay, + }, + ) + return delay + + # TODO: Needs unit testing for room ID + alias joins class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): CATEGORY = "Event sending requests" @@ -417,7 +520,6 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): super().__init__(hs) super(ResolveRoomIdMixin, self).__init__(hs) # ensure the Mixin is set up self.auth = hs.get_auth() - self._support_via = hs.config.experimental.msc4156_enabled def register(self, http_server: HttpServer) -> None: # /join/$room_identifier[/$txn_id] @@ -435,13 +537,11 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): # twisted.web.server.Request.args is incorrectly defined as Optional[Any] args: Dict[bytes, List[bytes]] = request.args # type: ignore - remote_room_hosts = parse_strings_from_args(args, "server_name", required=False) - if self._support_via: + # Prefer via over server_name (deprecated with MSC4156) + remote_room_hosts = parse_strings_from_args(args, "via", required=False) + if remote_room_hosts is None: remote_room_hosts = parse_strings_from_args( - args, - "org.matrix.msc4156.via", - default=remote_room_hosts, - required=False, + args, "server_name", required=False ) room_id, remote_room_hosts = await self.resolve_room_id( room_identifier, @@ -703,9 +803,9 @@ class RoomMessageListRestServlet(RestServlet): # decorator on `get_number_joined_users_in_room` doesn't play well with # 
the type system. Maybe in the future, it can use some ParamSpec # wizardry to fix it up. - room_member_count_deferred = run_in_background( # type: ignore[call-arg] + room_member_count_deferred = run_in_background( # type: ignore[call-overload] self.store.get_number_joined_users_in_room, - room_id, # type: ignore[arg-type] + room_id, ) requester = await self.auth.get_user_by_req(request, allow_guest=True) @@ -814,12 +914,10 @@ class RoomEventServlet(RestServlet): requester = await self.auth.get_user_by_req(request, allow_guest=True) include_unredacted_content = self.msc2815_enabled and ( - parse_string( + parse_boolean( request, - "fi.mau.msc2815.include_unredacted_content", - allowed_values=("true", "false"), + "fi.mau.msc2815.include_unredacted_content" ) - == "true" ) if include_unredacted_content and not await self.auth.is_server_admin( requester @@ -1193,7 +1291,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet): event_id = event.event_id except ShadowBanError: - event_id = "$" + random_string(43) + event_id = generate_fake_event_id() set_tag("event_id", event_id) return 200, {"event_id": event_id} diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 8c5db2a513..bac02122d0 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py
@@ -21,12 +21,13 @@ import itertools import logging from collections import defaultdict -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union from synapse.api.constants import AccountDataTypes, EduTypes, Membership, PresenceState from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState +from synapse.api.ratelimiting import Ratelimiter from synapse.events.utils import ( SerializeEventConfig, format_event_for_client_v2_without_room_id, @@ -126,6 +127,13 @@ class SyncRestServlet(RestServlet): cache_name="sync_valid_filter", ) + # Ratelimiter for presence updates, keyed by requester. + self._presence_per_user_limiter = Ratelimiter( + store=self.store, + clock=self.clock, + cfg=hs.config.ratelimiting.rc_presence_per_user, + ) + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: # This will always be set by the time Twisted calls us. 
assert request.args is not None @@ -152,6 +160,14 @@ class SyncRestServlet(RestServlet): filter_id = parse_string(request, "filter") full_state = parse_boolean(request, "full_state", default=False) + use_state_after = False + if await self.store.is_feature_enabled( + user.to_string(), ExperimentalFeature.MSC4222 + ): + use_state_after = parse_boolean( + request, "org.matrix.msc4222.use_state_after", default=False + ) + logger.debug( "/sync: user=%r, timeout=%r, since=%r, " "set_presence=%r, filter_id=%r, device_id=%r", @@ -184,6 +200,7 @@ class SyncRestServlet(RestServlet): full_state, device_id, last_ignore_accdata_streampos, + use_state_after, ) if filter_id is None: @@ -220,6 +237,7 @@ class SyncRestServlet(RestServlet): filter_collection=filter_collection, is_guest=requester.is_guest, device_id=device_id, + use_state_after=use_state_after, ) since_token = None @@ -229,7 +247,13 @@ class SyncRestServlet(RestServlet): # send any outstanding server notices to the user. await self._server_notices_sender.on_user_syncing(user.to_string()) - affect_presence = set_presence != PresenceState.OFFLINE + # ignore the presence update if the ratelimit is exceeded but do not pause the request + allowed, _ = await self._presence_per_user_limiter.can_do_action(requester) + if not allowed: + affect_presence = False + logger.debug("User set_presence ratelimit exceeded; ignoring it.") + else: + affect_presence = set_presence != PresenceState.OFFLINE context = await self.presence_handler.user_syncing( user.to_string(), @@ -258,7 +282,7 @@ class SyncRestServlet(RestServlet): # We know that the the requester has an access token since appservices # cannot use sync. 
response_content = await self.encode_response( - time_now, sync_result, requester, filter_collection + time_now, sync_config, sync_result, requester, filter_collection ) logger.debug("Event formatting complete") @@ -268,6 +292,7 @@ class SyncRestServlet(RestServlet): async def encode_response( self, time_now: int, + sync_config: SyncConfig, sync_result: SyncResult, requester: Requester, filter: FilterCollection, @@ -292,7 +317,7 @@ class SyncRestServlet(RestServlet): ) joined = await self.encode_joined( - sync_result.joined, time_now, serialize_options + sync_config, sync_result.joined, time_now, serialize_options ) invited = await self.encode_invited( @@ -304,7 +329,7 @@ class SyncRestServlet(RestServlet): ) archived = await self.encode_archived( - sync_result.archived, time_now, serialize_options + sync_config, sync_result.archived, time_now, serialize_options ) logger.debug("building sync response dict") @@ -372,6 +397,7 @@ class SyncRestServlet(RestServlet): @trace_with_opname("sync.encode_joined") async def encode_joined( self, + sync_config: SyncConfig, rooms: List[JoinedSyncResult], time_now: int, serialize_options: SerializeEventConfig, @@ -380,6 +406,7 @@ class SyncRestServlet(RestServlet): Encode the joined rooms in a sync result Args: + sync_config rooms: list of sync results for rooms this user is joined to time_now: current time - used as a baseline for age calculations serialize_options: Event serializer options @@ -389,7 +416,11 @@ class SyncRestServlet(RestServlet): joined = {} for room in rooms: joined[room.room_id] = await self.encode_room( - room, time_now, joined=True, serialize_options=serialize_options + sync_config, + room, + time_now, + joined=True, + serialize_options=serialize_options, ) return joined @@ -419,7 +450,12 @@ class SyncRestServlet(RestServlet): ) unsigned = dict(invite.get("unsigned", {})) invite["unsigned"] = unsigned - invited_state = list(unsigned.pop("invite_room_state", [])) + + invited_state = 
unsigned.pop("invite_room_state", []) + if not isinstance(invited_state, list): + invited_state = [] + + invited_state = list(invited_state) invited_state.append(invite) invited[room.room_id] = {"invite_state": {"events": invited_state}} @@ -459,7 +495,10 @@ class SyncRestServlet(RestServlet): # Extract the stripped room state from the unsigned dict # This is for clients to get a little bit of information about # the room they've knocked on, without revealing any sensitive information - knocked_state = list(unsigned.pop("knock_room_state", [])) + knocked_state = unsigned.pop("knock_room_state", []) + if not isinstance(knocked_state, list): + knocked_state = [] + knocked_state = list(knocked_state) # Append the actual knock membership event itself as well. This provides # the client with: @@ -477,6 +516,7 @@ class SyncRestServlet(RestServlet): @trace_with_opname("sync.encode_archived") async def encode_archived( self, + sync_config: SyncConfig, rooms: List[ArchivedSyncResult], time_now: int, serialize_options: SerializeEventConfig, @@ -485,6 +525,7 @@ class SyncRestServlet(RestServlet): Encode the archived rooms in a sync result Args: + sync_config rooms: list of sync results for rooms this user is joined to time_now: current time - used as a baseline for age calculations serialize_options: Event serializer options @@ -494,13 +535,18 @@ class SyncRestServlet(RestServlet): joined = {} for room in rooms: joined[room.room_id] = await self.encode_room( - room, time_now, joined=False, serialize_options=serialize_options + sync_config, + room, + time_now, + joined=False, + serialize_options=serialize_options, ) return joined async def encode_room( self, + sync_config: SyncConfig, room: Union[JoinedSyncResult, ArchivedSyncResult], time_now: int, joined: bool, @@ -508,6 +554,7 @@ class SyncRestServlet(RestServlet): ) -> JsonDict: """ Args: + sync_config room: sync result for a single room time_now: current time - used as a baseline for age calculations token_id: ID of the 
user's auth token - used for namespacing @@ -548,13 +595,20 @@ class SyncRestServlet(RestServlet): account_data = room.account_data + # We either include a `state` or `state_after` field depending on + # whether the client has opted in to the newer `state_after` behavior. + if sync_config.use_state_after: + state_key_name = "org.matrix.msc4222.state_after" + else: + state_key_name = "state" + result: JsonDict = { "timeline": { "events": serialized_timeline, "prev_batch": await room.timeline.prev_batch.to_string(self.store), "limited": room.timeline.limited, }, - "state": {"events": serialized_state}, + state_key_name: {"events": serialized_state}, "account_data": {"events": account_data}, } @@ -688,6 +742,7 @@ class SlidingSyncE2eeRestServlet(RestServlet): filter_collection=self.only_member_events_filter_collection, is_guest=requester.is_guest, device_id=device_id, + use_state_after=False, # We don't return any rooms so this flag is a no-op ) since_token = None @@ -975,7 +1030,7 @@ class SlidingSyncRestServlet(RestServlet): return response def encode_lists( - self, lists: Dict[str, SlidingSyncResult.SlidingWindowList] + self, lists: Mapping[str, SlidingSyncResult.SlidingWindowList] ) -> JsonDict: def encode_operation( operation: SlidingSyncResult.SlidingWindowList.Operation, @@ -1010,13 +1065,19 @@ class SlidingSyncRestServlet(RestServlet): serialized_rooms: Dict[str, JsonDict] = {} for room_id, room_result in rooms.items(): serialized_rooms[room_id] = { - "bump_stamp": room_result.bump_stamp, - "joined_count": room_result.joined_count, - "invited_count": room_result.invited_count, "notification_count": room_result.notification_count, "highlight_count": room_result.highlight_count, } + if room_result.bump_stamp is not None: + serialized_rooms[room_id]["bump_stamp"] = room_result.bump_stamp + + if room_result.joined_count is not None: + serialized_rooms[room_id]["joined_count"] = room_result.joined_count + + if room_result.invited_count is not None: + 
serialized_rooms[room_id]["invited_count"] = room_result.invited_count + if room_result.name: serialized_rooms[room_id]["name"] = room_result.name @@ -1040,10 +1101,15 @@ class SlidingSyncRestServlet(RestServlet): serialized_rooms[room_id]["heroes"] = serialized_heroes # We should only include the `initial` key if it's `True` to save bandwidth. - # The absense of this flag means `False`. + # The absence of this flag means `False`. if room_result.initial: serialized_rooms[room_id]["initial"] = room_result.initial + if room_result.unstable_expanded_timeline: + serialized_rooms[room_id]["unstable_expanded_timeline"] = ( + room_result.unstable_expanded_timeline + ) + # This will be omitted for invite/knock rooms with `stripped_state` if ( room_result.required_state is not None @@ -1077,9 +1143,9 @@ class SlidingSyncRestServlet(RestServlet): # This will be omitted for invite/knock rooms with `stripped_state` if room_result.prev_batch is not None: - serialized_rooms[room_id]["prev_batch"] = ( - await room_result.prev_batch.to_string(self.store) - ) + serialized_rooms[room_id][ + "prev_batch" + ] = await room_result.prev_batch.to_string(self.store) # This will be omitted for invite/knock rooms with `stripped_state` if room_result.num_live is not None: diff --git a/synapse/rest/client/tags.py b/synapse/rest/client/tags.py
index 554bcb95dd..b6648f3499 100644 --- a/synapse/rest/client/tags.py +++ b/synapse/rest/client/tags.py
@@ -78,6 +78,7 @@ class TagServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.handler = hs.get_account_data_handler() + self.room_member_handler = hs.get_room_member_handler() async def on_PUT( self, request: SynapseRequest, user_id: str, room_id: str, tag: str @@ -85,6 +86,12 @@ class TagServlet(RestServlet): requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add tags for other users.") + # Check if the user has any membership in the room and raise error if not. + # Although it's not harmful for users to tag random rooms, it's just superfluous + # data we don't need to track or allow. + await self.room_member_handler.check_for_any_membership_in_room( + user_id=user_id, room_id=room_id + ) body = parse_json_object_from_request(request) diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index 30c1f17fc6..1a57996aec 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py
@@ -21,6 +21,7 @@ """This module contains logic for storing HTTP PUT transactions. This is used to ensure idempotency when performing PUTs using the REST API.""" + import logging from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Hashable, Tuple @@ -93,9 +94,9 @@ class HttpTransactionCache: # (appservice and guest users), but does not cover access tokens minted # by the admin API. Use the access token ID instead. else: - assert ( - requester.access_token_id is not None - ), "Requester must have an access_token_id" + assert requester.access_token_id is not None, ( + "Requester must have an access_token_id" + ) return (path, "user_admin", requester.access_token_id) def fetch_or_execute_request( diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 75df684416..f58f11e5cc 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py
@@ -4,7 +4,7 @@ # Copyright 2019 The Matrix.org Foundation C.I.C. # Copyright 2017 Vector Creations Ltd # Copyright 2016 OpenMarket Ltd -# Copyright (C) 2023 New Vector, Ltd +# Copyright (C) 2023-2024 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as @@ -64,6 +64,7 @@ class VersionsRestServlet(RestServlet): async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: msc3881_enabled = self.config.experimental.msc3881_enabled + msc3575_enabled = self.config.experimental.msc3575_enabled if self.auth.has_access_token(request): requester = await self.auth.get_user_by_req( @@ -77,6 +78,9 @@ class VersionsRestServlet(RestServlet): msc3881_enabled = await self.store.is_feature_enabled( user_id, ExperimentalFeature.MSC3881 ) + msc3575_enabled = await self.store.is_feature_enabled( + user_id, ExperimentalFeature.MSC3575 + ) return ( 200, @@ -145,9 +149,6 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3881": msc3881_enabled, # Adds support for filtering /messages by event relation. "org.matrix.msc3874": self.config.experimental.msc3874_enabled, - # Adds support for simple HTTP rendezvous as per MSC3886 - "org.matrix.msc3886": self.config.experimental.msc3886_endpoint - is not None, # Adds support for relation-based redactions as per MSC3912. "org.matrix.msc3912": self.config.experimental.msc3912_enabled, # Whether recursively provide relations is supported. @@ -167,8 +168,14 @@ class VersionsRestServlet(RestServlet): is not None ) ), - # MSC4151: Report room API (Client-Server API) - "org.matrix.msc4151": self.config.experimental.msc4151_enabled, + # MSC4140: Delayed events + "org.matrix.msc4140": bool(self.config.server.max_event_delay_ms), + # Simplified sliding sync + "org.matrix.simplified_msc3575": msc3575_enabled, + # Arbitrary key-value profile fields. 
+ "uk.tcpip.msc4133": self.config.experimental.msc4133_enabled, + # MSC4155: Invite filtering + "org.matrix.msc4155": self.config.experimental.msc4155_enabled, }, }, ) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index a411ed614e..fea0b9706d 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -23,17 +23,11 @@ import logging import re from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple -from synapse._pydantic_compat import HAS_PYDANTIC_V2 - -if TYPE_CHECKING or HAS_PYDANTIC_V2: - from pydantic.v1 import Extra, StrictInt, StrictStr -else: - from pydantic import StrictInt, StrictStr, Extra - from signedjson.sign import sign_json from twisted.web.server import Request +from synapse._pydantic_compat import Extra, StrictInt, StrictStr from synapse.crypto.keyring import ServerKeyFetcher from synapse.http.server import HttpServer from synapse.http.servlet import ( @@ -191,10 +185,10 @@ class RemoteKey(RestServlet): server_keys: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {} for server_name, key_ids in query.items(): if key_ids: - results: Mapping[str, Optional[FetchKeyResultForRemote]] = ( - await self.store.get_server_keys_json_for_remote( - server_name, key_ids - ) + results: Mapping[ + str, Optional[FetchKeyResultForRemote] + ] = await self.store.get_server_keys_json_for_remote( + server_name, key_ids ) else: results = await self.store.get_all_server_keys_json_for_remote( diff --git a/synapse/rest/media/config_resource.py b/synapse/rest/media/config_resource.py
index 80462d65d3..b014e91bdb 100644 --- a/synapse/rest/media/config_resource.py +++ b/synapse/rest/media/config_resource.py
@@ -40,7 +40,14 @@ class MediaConfigResource(RestServlet): self.clock = hs.get_clock() self.auth = hs.get_auth() self.limits_dict = {"m.upload.size": config.media.max_upload_size} + self.media_repository_callbacks = hs.get_module_api_callbacks().media_repository async def on_GET(self, request: SynapseRequest) -> None: - await self.auth.get_user_by_req(request) - respond_with_json(request, 200, self.limits_dict, send_cors=True) + requester = await self.auth.get_user_by_req(request) + user_specific_config = ( + await self.media_repository_callbacks.get_media_config_for_user( + requester.user.to_string() + ) + ) + response = user_specific_config if user_specific_config else self.limits_dict + respond_with_json(request, 200, response, send_cors=True) diff --git a/synapse/rest/media/upload_resource.py b/synapse/rest/media/upload_resource.py
index 5ef6bf8836..572f7897fd 100644 --- a/synapse/rest/media/upload_resource.py +++ b/synapse/rest/media/upload_resource.py
@@ -50,9 +50,12 @@ class BaseUploadServlet(RestServlet): self.server_name = hs.hostname self.auth = hs.get_auth() self.max_upload_size = hs.config.media.max_upload_size + self._media_repository_callbacks = ( + hs.get_module_api_callbacks().media_repository + ) - def _get_file_metadata( - self, request: SynapseRequest + async def _get_file_metadata( + self, request: SynapseRequest, user_id: str ) -> Tuple[int, Optional[str], str]: raw_content_length = request.getHeader("Content-Length") if raw_content_length is None: @@ -67,7 +70,14 @@ class BaseUploadServlet(RestServlet): code=413, errcode=Codes.TOO_LARGE, ) - + if not await self._media_repository_callbacks.is_user_allowed_to_upload_media_of_size( + user_id, content_length + ): + raise SynapseError( + msg="Upload request body is too large", + code=413, + errcode=Codes.TOO_LARGE, + ) args: Dict[bytes, List[bytes]] = request.args # type: ignore upload_name_bytes = parse_bytes_from_args(args, "filename") if upload_name_bytes: @@ -94,7 +104,7 @@ class BaseUploadServlet(RestServlet): # if headers.hasHeader(b"Content-Disposition"): # disposition = headers.getRawHeaders(b"Content-Disposition")[0] - # TODO(markjh): parse content-dispostion + # TODO(markjh): parse content-disposition return content_length, upload_name, media_type @@ -104,7 +114,9 @@ class UploadServlet(BaseUploadServlet): async def on_POST(self, request: SynapseRequest) -> None: requester = await self.auth.get_user_by_req(request) - content_length, upload_name, media_type = self._get_file_metadata(request) + content_length, upload_name, media_type = await self._get_file_metadata( + request, requester.user.to_string() + ) try: content: IO = request.content # type: ignore @@ -152,7 +164,9 @@ class AsyncUploadServlet(BaseUploadServlet): async with lock: await self.media_repo.verify_can_upload(media_id, requester.user) - content_length, upload_name, media_type = self._get_file_metadata(request) + content_length, upload_name, media_type = await 
self._get_file_metadata( + request, requester.user.to_string() + ) try: content: IO = request.content # type: ignore diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py
index 7b5bfc0421..982b6c0e7e 100644 --- a/synapse/rest/synapse/client/__init__.py +++ b/synapse/rest/synapse/client/__init__.py
@@ -29,7 +29,6 @@ from synapse.rest.synapse.client.pick_idp import PickIdpResource from synapse.rest.synapse.client.pick_username import pick_username_resource from synapse.rest.synapse.client.rendezvous import MSC4108RendezvousSessionResource from synapse.rest.synapse.client.sso_register import SsoRegisterResource -from synapse.rest.synapse.client.unsubscribe import UnsubscribeResource if TYPE_CHECKING: from synapse.server import HomeServer @@ -50,9 +49,7 @@ def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resourc "/_synapse/client/pick_idp": PickIdpResource(hs), "/_synapse/client/pick_username": pick_username_resource(hs), "/_synapse/client/new_user_consent": NewUserConsentResource(hs), - "/_synapse/client/sso_register": SsoRegisterResource(hs), - # Unsubscribe to notification emails link - "/_synapse/client/unsubscribe": UnsubscribeResource(hs), + "/_synapse/client/sso_register": SsoRegisterResource(hs) } # Expose the JWKS endpoint if OAuth2 delegation is enabled @@ -68,16 +65,6 @@ def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resourc resources["/_synapse/client/oidc"] = OIDCResource(hs) - if hs.config.saml2.saml2_enabled: - from synapse.rest.synapse.client.saml2 import SAML2Resource - - res = SAML2Resource(hs) - resources["/_synapse/client/saml2"] = res - - # This is also mounted under '/_matrix' for backwards-compatibility. - # To be removed in Synapse v1.32.0. - resources["/_matrix/saml2"] = res - if hs.config.federation.federation_whitelist_endpoint_enabled: resources[FederationWhitelistResource.PATH] = FederationWhitelistResource(hs) diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py deleted file mode 100644
index 29e4b2d07a..0000000000 --- a/synapse/rest/synapse/client/password_reset.py +++ /dev/null
@@ -1,129 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright 2020 The Matrix.org Foundation C.I.C. -# Copyright (C) 2023 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -# Originally licensed under the Apache License, Version 2.0: -# <http://www.apache.org/licenses/LICENSE-2.0>. -# -# [This file includes modifications made by New Vector Limited] -# -# -import logging -from typing import TYPE_CHECKING, Tuple - -from twisted.web.server import Request - -from synapse.api.errors import ThreepidValidationError -from synapse.http.server import DirectServeHtmlResource -from synapse.http.servlet import parse_string -from synapse.util.stringutils import assert_valid_client_secret - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -class PasswordResetSubmitTokenResource(DirectServeHtmlResource): - """Handles 3PID validation token submission - - This resource gets mounted under /_synapse/client/password_reset/email/submit_token - """ - - isLeaf = 1 - - def __init__(self, hs: "HomeServer"): - """ - Args: - hs: server - """ - super().__init__() - - self.clock = hs.get_clock() - self.store = hs.get_datastores().main - - self._confirmation_email_template = ( - hs.config.email.email_password_reset_template_confirmation_html - ) - self._email_password_reset_template_success_html = ( - hs.config.email.email_password_reset_template_success_html_content - ) - self._failure_email_template = ( - hs.config.email.email_password_reset_template_failure_html - ) - - # This resource should only be mounted if email validation is enabled - 
assert hs.config.email.can_verify_email - - async def _async_render_GET(self, request: Request) -> Tuple[int, bytes]: - sid = parse_string(request, "sid", required=True) - token = parse_string(request, "token", required=True) - client_secret = parse_string(request, "client_secret", required=True) - assert_valid_client_secret(client_secret) - - # Show a confirmation page, just in case someone accidentally clicked this link when - # they didn't mean to - template_vars = { - "sid": sid, - "token": token, - "client_secret": client_secret, - } - return ( - 200, - self._confirmation_email_template.render(**template_vars).encode("utf-8"), - ) - - async def _async_render_POST(self, request: Request) -> Tuple[int, bytes]: - sid = parse_string(request, "sid", required=True) - token = parse_string(request, "token", required=True) - client_secret = parse_string(request, "client_secret", required=True) - - # Attempt to validate a 3PID session - try: - # Mark the session as valid - next_link = await self.store.validate_threepid_session( - sid, client_secret, token, self.clock.time_msec() - ) - - # Perform a 302 redirect if next_link is set - if next_link: - if next_link.startswith("file:///"): - logger.warning( - "Not redirecting to next_link as it is a local file: address" - ) - else: - next_link_bytes = next_link.encode("utf-8") - request.setHeader("Location", next_link_bytes) - return ( - 302, - ( - b'You are being redirected to <a href="%s">%s</a>.' 
- % (next_link_bytes, next_link_bytes) - ), - ) - - # Otherwise show the success template - html_bytes = self._email_password_reset_template_success_html.encode( - "utf-8" - ) - status_code = 200 - except ThreepidValidationError as e: - status_code = e.code - - # Show a failure page with a reason - template_vars = {"failure_reason": e.msg} - html_bytes = self._failure_email_template.render(**template_vars).encode( - "utf-8" - ) - - return status_code, html_bytes diff --git a/synapse/rest/synapse/client/pick_idp.py b/synapse/rest/synapse/client/pick_idp.py
index f26929bd60..5e599f85b0 100644 --- a/synapse/rest/synapse/client/pick_idp.py +++ b/synapse/rest/synapse/client/pick_idp.py
@@ -21,6 +21,7 @@ import logging from typing import TYPE_CHECKING +from synapse.api.urls import LoginSSORedirectURIBuilder from synapse.http.server import ( DirectServeHtmlResource, finish_request, @@ -49,6 +50,8 @@ class PickIdpResource(DirectServeHtmlResource): hs.config.sso.sso_login_idp_picker_template ) self._server_name = hs.hostname + self._public_baseurl = hs.config.server.public_baseurl + self._login_sso_redirect_url_builder = LoginSSORedirectURIBuilder(hs.config) async def _async_render_GET(self, request: SynapseRequest) -> None: client_redirect_url = parse_string( @@ -56,25 +59,23 @@ class PickIdpResource(DirectServeHtmlResource): ) idp = parse_string(request, "idp", required=False) - # if we need to pick an IdP, do so + # If we need to pick an IdP, do so if not idp: return await self._serve_id_picker(request, client_redirect_url) - # otherwise, redirect to the IdP's redirect URI - providers = self._sso_handler.get_identity_providers() - auth_provider = providers.get(idp) - if not auth_provider: - logger.info("Unknown idp %r", idp) - self._sso_handler.render_error( - request, "unknown_idp", "Unknown identity provider ID" + # Otherwise, redirect to the login SSO redirect endpoint for the given IdP + # (which will in turn take us to the the IdP's redirect URI). + # + # We could go directly to the IdP's redirect URI, but this way we ensure that + # the user goes through the same logic as normal flow. Additionally, if a proxy + # needs to intercept the request, it only needs to intercept the one endpoint. 
+ sso_login_redirect_url = ( + self._login_sso_redirect_url_builder.build_login_sso_redirect_uri( + idp_id=idp, client_redirect_url=client_redirect_url ) - return - - sso_url = await auth_provider.handle_redirect_request( - request, client_redirect_url.encode("utf8") ) - logger.info("Redirecting to %s", sso_url) - request.redirect(sso_url) + logger.info("Redirecting to %s", sso_login_redirect_url) + request.redirect(sso_login_redirect_url) finish_request(request) async def _serve_id_picker( diff --git a/synapse/rest/synapse/client/saml2/__init__.py b/synapse/rest/synapse/client/saml2/__init__.py deleted file mode 100644
index 3658c6a0e3..0000000000 --- a/synapse/rest/synapse/client/saml2/__init__.py +++ /dev/null
@@ -1,42 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2023 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -# Originally licensed under the Apache License, Version 2.0: -# <http://www.apache.org/licenses/LICENSE-2.0>. -# -# [This file includes modifications made by New Vector Limited] -# -# - -import logging -from typing import TYPE_CHECKING - -from twisted.web.resource import Resource - -from synapse.rest.synapse.client.saml2.metadata_resource import SAML2MetadataResource -from synapse.rest.synapse.client.saml2.response_resource import SAML2ResponseResource - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -class SAML2Resource(Resource): - def __init__(self, hs: "HomeServer"): - Resource.__init__(self) - self.putChild(b"metadata.xml", SAML2MetadataResource(hs)) - self.putChild(b"authn_response", SAML2ResponseResource(hs)) - - -__all__ = ["SAML2Resource"] diff --git a/synapse/rest/synapse/client/saml2/metadata_resource.py b/synapse/rest/synapse/client/saml2/metadata_resource.py deleted file mode 100644
index bcd5195108..0000000000 --- a/synapse/rest/synapse/client/saml2/metadata_resource.py +++ /dev/null
@@ -1,46 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2023 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -# Originally licensed under the Apache License, Version 2.0: -# <http://www.apache.org/licenses/LICENSE-2.0>. -# -# [This file includes modifications made by New Vector Limited] -# -# - -from typing import TYPE_CHECKING - -import saml2.metadata - -from twisted.web.resource import Resource -from twisted.web.server import Request - -if TYPE_CHECKING: - from synapse.server import HomeServer - - -class SAML2MetadataResource(Resource): - """A Twisted web resource which renders the SAML metadata""" - - isLeaf = 1 - - def __init__(self, hs: "HomeServer"): - Resource.__init__(self) - self.sp_config = hs.config.saml2.saml2_sp_config - - def render_GET(self, request: Request) -> bytes: - metadata_xml = saml2.metadata.create_metadata_string( - configfile=None, config=self.sp_config - ) - request.setHeader(b"Content-Type", b"text/xml; charset=utf-8") - return metadata_xml diff --git a/synapse/rest/synapse/client/saml2/response_resource.py b/synapse/rest/synapse/client/saml2/response_resource.py deleted file mode 100644
index 7b8667e04c..0000000000 --- a/synapse/rest/synapse/client/saml2/response_resource.py +++ /dev/null
@@ -1,52 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2023 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -# Originally licensed under the Apache License, Version 2.0: -# <http://www.apache.org/licenses/LICENSE-2.0>. -# -# [This file includes modifications made by New Vector Limited] -# -# - -from typing import TYPE_CHECKING - -from twisted.web.server import Request - -from synapse.http.server import DirectServeHtmlResource -from synapse.http.site import SynapseRequest - -if TYPE_CHECKING: - from synapse.server import HomeServer - - -class SAML2ResponseResource(DirectServeHtmlResource): - """A Twisted web resource which handles the SAML response""" - - isLeaf = 1 - - def __init__(self, hs: "HomeServer"): - super().__init__() - self._saml_handler = hs.get_saml_handler() - self._sso_handler = hs.get_sso_handler() - - async def _async_render_GET(self, request: Request) -> None: - # We're not expecting any GET request on that resource if everything goes right, - # but some IdPs sometimes end up responding with a 302 redirect on this endpoint. - # In this case, just tell the user that something went wrong and they should - # try to authenticate again. - self._sso_handler.render_error( - request, "unexpected_get", "Unexpected GET request on /saml2/authn_response" - ) - - async def _async_render_POST(self, request: SynapseRequest) -> None: - await self._saml_handler.handle_saml_response(request) diff --git a/synapse/rest/synapse/client/unsubscribe.py b/synapse/rest/synapse/client/unsubscribe.py deleted file mode 100644
index 6d4bd9f2ed..0000000000 --- a/synapse/rest/synapse/client/unsubscribe.py +++ /dev/null
@@ -1,88 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright 2022 The Matrix.org Foundation C.I.C. -# Copyright (C) 2023 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -# Originally licensed under the Apache License, Version 2.0: -# <http://www.apache.org/licenses/LICENSE-2.0>. -# -# [This file includes modifications made by New Vector Limited] -# -# - -from typing import TYPE_CHECKING - -from synapse.api.errors import StoreError -from synapse.http.server import DirectServeHtmlResource, respond_with_html_bytes -from synapse.http.servlet import parse_string -from synapse.http.site import SynapseRequest - -if TYPE_CHECKING: - from synapse.server import HomeServer - - -class UnsubscribeResource(DirectServeHtmlResource): - """ - To allow pusher to be delete by clicking a link (ie. GET request) - """ - - SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>" - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.notifier = hs.get_notifier() - self.auth = hs.get_auth() - self.pusher_pool = hs.get_pusherpool() - self.macaroon_generator = hs.get_macaroon_generator() - - async def _async_render_GET(self, request: SynapseRequest) -> None: - """ - Handle a user opening an unsubscribe link in the browser, either via an - HTML/Text email or via the List-Unsubscribe header. 
- """ - token = parse_string(request, "access_token", required=True) - app_id = parse_string(request, "app_id", required=True) - pushkey = parse_string(request, "pushkey", required=True) - - user_id = self.macaroon_generator.verify_delete_pusher_token( - token, app_id, pushkey - ) - - try: - await self.pusher_pool.remove_pusher( - app_id=app_id, pushkey=pushkey, user_id=user_id - ) - except StoreError as se: - if se.code != 404: - # This is fine: they're already unsubscribed - raise - - self.notifier.on_new_replication_data() - - respond_with_html_bytes( - request, - 200, - UnsubscribeResource.SUCCESS_HTML, - ) - - async def _async_render_POST(self, request: SynapseRequest) -> None: - """ - Handle a mail user agent POSTing to the unsubscribe URL via the - List-Unsubscribe & List-Unsubscribe-Post headers. - """ - - # TODO Assert that the body has a single field - - # Assert the body has form encoded key/value pair of - # List-Unsubscribe=One-Click. - - await self._async_render_GET(request) diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
index d0ca8ca46b..9ce1eb6249 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py
@@ -18,12 +18,13 @@ # # import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Tuple, cast from twisted.web.resource import Resource from twisted.web.server import Request -from synapse.http.server import set_cors_headers +from synapse.api.errors import NotFoundError +from synapse.http.server import DirectServeJsonResource from synapse.http.site import SynapseRequest from synapse.types import JsonDict from synapse.util import json_encoder @@ -38,27 +39,30 @@ logger = logging.getLogger(__name__) class WellKnownBuilder: def __init__(self, hs: "HomeServer"): self._config = hs.config + self._auth = hs.get_auth() - def get_well_known(self) -> Optional[JsonDict]: + async def get_well_known(self) -> Optional[JsonDict]: if not self._config.server.serve_client_wellknown: return None result = {"m.homeserver": {"base_url": self._config.server.public_baseurl}} - if self._config.registration.default_identity_server: - result["m.identity_server"] = { - "base_url": self._config.registration.default_identity_server - } - # We use the MSC3861 values as they are used by multiple MSCs if self._config.experimental.msc3861.enabled: + # If MSC3861 is enabled, we can assume self._auth is an instance of MSC3861DelegatedAuth + # We import lazily here because of the authlib requirement + from synapse.api.auth.msc3861_delegated import MSC3861DelegatedAuth + + auth = cast(MSC3861DelegatedAuth, self._auth) + result["org.matrix.msc2965.authentication"] = { - "issuer": self._config.experimental.msc3861.issuer + "issuer": await auth.issuer(), } - if self._config.experimental.msc3861.account_management_url is not None: - result["org.matrix.msc2965.authentication"][ - "account" - ] = self._config.experimental.msc3861.account_management_url + account_management_url = await auth.account_management_url() + if account_management_url is not None: + result["org.matrix.msc2965.authentication"]["account"] = ( + account_management_url + ) if 
self._config.server.extra_well_known_client_content: for ( @@ -71,26 +75,22 @@ class WellKnownBuilder: return result -class ClientWellKnownResource(Resource): +class ClientWellKnownResource(DirectServeJsonResource): """A Twisted web resource which renders the .well-known/matrix/client file""" isLeaf = 1 def __init__(self, hs: "HomeServer"): - Resource.__init__(self) + super().__init__() self._well_known_builder = WellKnownBuilder(hs) - def render_GET(self, request: SynapseRequest) -> bytes: - set_cors_headers(request) - r = self._well_known_builder.get_well_known() + async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + r = await self._well_known_builder.get_well_known() if not r: - request.setResponseCode(404) - request.setHeader(b"Content-Type", b"text/plain") - return b".well-known not available" + raise NotFoundError(".well-known not available") logger.debug("returning: %s", r) - request.setHeader(b"Content-Type", b"application/json") - return json_encoder.encode(r).encode("utf-8") + return 200, r class ServerWellKnownResource(Resource): diff --git a/synapse/server.py b/synapse/server.py
index 46b9d83a04..03cca9eeed 100644 --- a/synapse/server.py +++ b/synapse/server.py
@@ -2,7 +2,7 @@ # This file is licensed under the Affero General Public License (AGPL) version 3. # # Copyright 2021 The Matrix.org Foundation C.I.C. -# Copyright (C) 2023 New Vector, Ltd +# Copyright (C) 2023-2024 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as @@ -34,6 +34,7 @@ from typing_extensions import TypeAlias from twisted.internet.interfaces import IOpenSSLContextFactory from twisted.internet.tcp import Port +from twisted.python.threadpool import ThreadPool from twisted.web.iweb import IPolicyForHTTPS from twisted.web.resource import Resource @@ -65,8 +66,8 @@ from synapse.handlers.account_validity import AccountValidityHandler from synapse.handlers.admin import AdminHandler from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.auth import AuthHandler, PasswordAuthProvider -from synapse.handlers.cas import CasHandler from synapse.handlers.deactivate_account import DeactivateAccountHandler +from synapse.handlers.delayed_events import DelayedEventsHandler from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler from synapse.handlers.devicemessage import DeviceMessageHandler from synapse.handlers.directory import DirectoryHandler @@ -76,7 +77,6 @@ from synapse.handlers.event_auth import EventAuthHandler from synapse.handlers.events import EventHandler, EventStreamHandler from synapse.handlers.federation import FederationHandler from synapse.handlers.federation_event import FederationEventHandler -from synapse.handlers.identity import IdentityHandler from synapse.handlers.initial_sync import InitialSyncHandler from synapse.handlers.message import EventCreationHandler, MessageHandler from synapse.handlers.pagination import PaginationHandler @@ -105,9 +105,9 @@ from synapse.handlers.room_member import ( RoomMemberMasterHandler, ) from synapse.handlers.room_member_worker import RoomMemberWorkerHandler +from 
synapse.handlers.room_policy import RoomPolicyHandler from synapse.handlers.room_summary import RoomSummaryHandler from synapse.handlers.search import SearchHandler -from synapse.handlers.send_email import SendEmailHandler from synapse.handlers.set_password import SetPasswordHandler from synapse.handlers.sliding_sync import SlidingSyncHandler from synapse.handlers.sso import SsoHandler @@ -123,6 +123,7 @@ from synapse.http.client import ( ) from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.media.media_repository import MediaRepository +from synapse.metrics import register_threadpool from synapse.metrics.common_usage_metrics import CommonUsageMetricsManager from synapse.module_api import ModuleApi from synapse.module_api.callbacks import ModuleApiCallbacks @@ -160,7 +161,6 @@ if TYPE_CHECKING: from synapse.handlers.jwt import JwtHandler from synapse.handlers.oidc import OidcHandler - from synapse.handlers.saml import SamlHandler from synapse.storage._base import SQLBaseStore @@ -246,9 +246,12 @@ class HomeServer(metaclass=abc.ABCMeta): """ REQUIRED_ON_BACKGROUND_TASK_STARTUP = [ + "admin", "account_validity", "auth", "deactivate_account", + "delayed_events", + "e2e_keys", # for the `delete_old_otks` scheduled-task handler "message", "pagination", "profile", @@ -385,7 +388,7 @@ class HomeServer(metaclass=abc.ABCMeta): def is_mine(self, domain_specific_string: DomainSpecificString) -> bool: return domain_specific_string.domain == self.hostname - def is_mine_id(self, string: str) -> bool: + def is_mine_id(self, user_id: str) -> bool: """Determines whether a user ID or room alias originates from this homeserver. Returns: @@ -393,7 +396,7 @@ class HomeServer(metaclass=abc.ABCMeta): homeserver. `False` otherwise, or if the user ID or room alias is malformed. 
""" - localpart_hostname = string.split(":", 1) + localpart_hostname = user_id.split(":", 1) if len(localpart_hostname) < 2: return False return localpart_hostname[1] == self.hostname @@ -633,10 +636,6 @@ class HomeServer(metaclass=abc.ABCMeta): return FederationEventHandler(self) @cache_in_self - def get_identity_handler(self) -> IdentityHandler: - return IdentityHandler(self) - - @cache_in_self def get_initial_sync_handler(self) -> InitialSyncHandler: return InitialSyncHandler(self) @@ -657,10 +656,6 @@ class HomeServer(metaclass=abc.ABCMeta): return SearchHandler(self) @cache_in_self - def get_send_email_handler(self) -> SendEmailHandler: - return SendEmailHandler(self) - - @cache_in_self def get_set_password_handler(self) -> SetPasswordHandler: return SetPasswordHandler(self) @@ -786,22 +781,16 @@ class HomeServer(metaclass=abc.ABCMeta): return AccountValidityHandler(self) @cache_in_self - def get_cas_handler(self) -> CasHandler: - return CasHandler(self) - - @cache_in_self - def get_saml_handler(self) -> "SamlHandler": - from synapse.handlers.saml import SamlHandler - - return SamlHandler(self) - - @cache_in_self def get_oidc_handler(self) -> "OidcHandler": from synapse.handlers.oidc import OidcHandler return OidcHandler(self) @cache_in_self + def get_room_policy_handler(self) -> RoomPolicyHandler: + return RoomPolicyHandler(self) + + @cache_in_self def get_event_client_serializer(self) -> EventClientSerializer: return EventClientSerializer(self) @@ -941,3 +930,28 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_task_scheduler(self) -> TaskScheduler: return TaskScheduler(self) + + @cache_in_self + def get_media_sender_thread_pool(self) -> ThreadPool: + """Fetch the threadpool used to read files when responding to media + download requests.""" + + # We can choose a large threadpool size as these threads predominately + # do IO rather than CPU work. 
+ media_threadpool = ThreadPool( + name="media_threadpool", minthreads=1, maxthreads=50 + ) + + media_threadpool.start() + self.get_reactor().addSystemEventTrigger( + "during", "shutdown", media_threadpool.stop + ) + + # Register the threadpool with our metrics. + register_threadpool("media", media_threadpool) + + return media_threadpool + + @cache_in_self + def get_delayed_events_handler(self) -> DelayedEventsHandler: + return DelayedEventsHandler(self) diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index f6ea90bd4f..e88e8c9b45 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -119,7 +119,9 @@ class ResourceLimitsServerNotices: elif not currently_blocked and limit_msg: # Room is not notifying of a block, when it ought to be. await self._apply_limit_block_notification( - user_id, limit_msg, limit_type # type: ignore + user_id, + limit_msg, + limit_type, # type: ignore ) except SynapseError as e: logger.error("Error sending resource limits server notice: %s", e) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 72b291889b..9e48e09270 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py
@@ -59,11 +59,13 @@ from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure, measure_func +from synapse.util.stringutils import shortstr if TYPE_CHECKING: from synapse.server import HomeServer from synapse.storage.controllers import StateStorageController from synapse.storage.databases.main import DataStore + from synapse.storage.databases.state.deletion import StateDeletionDataStore logger = logging.getLogger(__name__) metrics_logger = logging.getLogger("synapse.state.metrics") @@ -194,6 +196,8 @@ class StateHandler: self._storage_controllers = hs.get_storage_controllers() self._events_shard_config = hs.config.worker.events_shard_config self._instance_name = hs.get_instance_name() + self._state_store = hs.get_datastores().state + self._state_deletion_store = hs.get_datastores().state_deletion self._update_current_state_client = ( ReplicationUpdateCurrentStateRestServlet.make_client(hs) @@ -355,6 +359,28 @@ class StateHandler: await_full_state=False, ) + # Ensure we still have the state groups we're relying on, and bump + # their usage time to avoid them being deleted from under us. + if entry.state_group: + missing_state_group = await self._state_deletion_store.check_state_groups_and_bump_deletion( + {entry.state_group} + ) + if missing_state_group: + raise Exception(f"Missing state group: {entry.state_group}") + elif entry.prev_group: + # We only rely on the prev group when persisting the event if we + # don't have an `entry.state_group`. 
+ missing_state_group = await self._state_deletion_store.check_state_groups_and_bump_deletion( + {entry.prev_group} + ) + + if missing_state_group: + # If we're missing the prev group then we can just clear the + # entries, and rely on `entry._state` (which must exist if + # `entry.state_group` is None) + entry.prev_group = None + entry.delta_ids = None + state_group_before_event_prev_group = entry.prev_group deltas_to_state_group_before_event = entry.delta_ids state_ids_before_event = None @@ -475,7 +501,10 @@ class StateHandler: @trace @measure_func() async def resolve_state_groups_for_events( - self, room_id: str, event_ids: StrCollection, await_full_state: bool = True + self, + room_id: str, + event_ids: StrCollection, + await_full_state: bool = True, ) -> _StateCacheEntry: """Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. @@ -511,6 +540,7 @@ class StateHandler: ) = await self._state_storage_controller.get_state_group_delta( state_group_id ) + return _StateCacheEntry( state=None, state_group=state_group_id, @@ -531,7 +561,9 @@ class StateHandler: room_version, state_to_resolve, None, - state_res_store=StateResolutionStore(self.store), + state_res_store=StateResolutionStore( + self.store, self._state_deletion_store + ), ) return result @@ -663,7 +695,25 @@ class StateResolutionHandler: async with self.resolve_linearizer.queue(group_names): cache = self._state_cache.get(group_names, None) if cache: - return cache + # Check that the returned cache entry doesn't point to deleted + # state groups. 
+ state_groups_to_check = set() + if cache.state_group is not None: + state_groups_to_check.add(cache.state_group) + + if cache.prev_group is not None: + state_groups_to_check.add(cache.prev_group) + + missing_state_groups = await state_res_store.state_deletion_store.check_state_groups_and_bump_deletion( + state_groups_to_check + ) + + if not missing_state_groups: + return cache + else: + # There are missing state groups, so let's remove the stale + # entry and continue as if it was a cache miss. + self._state_cache.pop(group_names, None) logger.info( "Resolving state for %s with groups %s", @@ -671,6 +721,16 @@ class StateResolutionHandler: list(group_names), ) + # We double check that none of the state groups have been deleted. + # They shouldn't be as all these state groups should be referenced. + missing_state_groups = await state_res_store.state_deletion_store.check_state_groups_and_bump_deletion( + group_names + ) + if missing_state_groups: + raise Exception( + f"State groups have been deleted: {shortstr(missing_state_groups)}" + ) + state_groups_histogram.observe(len(state_groups_ids)) new_state = await self.resolve_events_with_store( @@ -884,7 +944,8 @@ class StateResolutionStore: in well defined way. """ - store: "DataStore" + main_store: "DataStore" + state_deletion_store: "StateDeletionDataStore" def get_events( self, event_ids: StrCollection, allow_rejected: bool = False @@ -899,7 +960,7 @@ class StateResolutionStore: An awaitable which resolves to a dict from event_id to event. """ - return self.store.get_events( + return self.main_store.get_events( event_ids, redact_behaviour=EventRedactBehaviour.as_is, get_prev_content=False, @@ -920,4 +981,4 @@ class StateResolutionStore: An awaitable that resolves to a set of event IDs. """ - return self.store.get_auth_chain_difference(room_id, state_sets) + return self.main_store.get_auth_chain_difference(room_id, state_sets) diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index da926ad146..d0c0a9fc96 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py
@@ -29,15 +29,15 @@ from typing import ( Generator, Iterable, List, + Literal, Optional, + Protocol, Sequence, Set, Tuple, overload, ) -from typing_extensions import Literal, Protocol - from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index e12ab94576..b5fe7dd858 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py
@@ -23,8 +23,11 @@ import logging from abc import ABCMeta from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Union -from synapse.storage.database import make_in_list_sql_clause # noqa: F401; noqa: F401 -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + make_in_list_sql_clause, # noqa: F401 +) from synapse.types import get_domain_from_id from synapse.util import json_decoder from synapse.util.caches.descriptors import CachedFunction @@ -83,7 +86,9 @@ class SQLBaseStore(metaclass=ABCMeta): """ def _invalidate_state_caches( - self, room_id: str, members_changed: Collection[str] + self, + room_id: str, + members_changed: Collection[str], ) -> None: """Invalidates caches that are based on the current state, but does not stream invalidations down replication. @@ -109,6 +114,7 @@ class SQLBaseStore(metaclass=ABCMeta): self._attempt_to_invalidate_cache( "get_number_joined_users_in_room", (room_id,) ) + self._attempt_to_invalidate_cache("get_member_counts", (room_id,)) self._attempt_to_invalidate_cache("get_local_users_in_room", (room_id,)) # There's no easy way of invalidating this cache for just the users @@ -123,12 +129,18 @@ class SQLBaseStore(metaclass=ABCMeta): self._attempt_to_invalidate_cache( "_get_rooms_for_local_user_where_membership_is_inner", (user_id,) ) + self._attempt_to_invalidate_cache( + "get_sliding_sync_rooms_for_user_from_membership_snapshots", (user_id,) + ) # Purge other caches based on room state. 
self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) self._attempt_to_invalidate_cache("get_partial_current_state_ids", (room_id,)) self._attempt_to_invalidate_cache("get_room_type", (room_id,)) self._attempt_to_invalidate_cache("get_room_encryption", (room_id,)) + self._attempt_to_invalidate_cache( + "get_sliding_sync_rooms_for_user_from_membership_snapshots", None + ) def _invalidate_state_caches_all(self, room_id: str) -> None: """Invalidates caches that are based on the current state, but does @@ -147,6 +159,7 @@ class SQLBaseStore(metaclass=ABCMeta): self._attempt_to_invalidate_cache("get_current_hosts_in_room", (room_id,)) self._attempt_to_invalidate_cache("get_users_in_room_with_profiles", (room_id,)) self._attempt_to_invalidate_cache("get_number_joined_users_in_room", (room_id,)) + self._attempt_to_invalidate_cache("get_member_counts", (room_id,)) self._attempt_to_invalidate_cache("get_local_users_in_room", (room_id,)) self._attempt_to_invalidate_cache("does_pair_of_users_share_a_room", None) self._attempt_to_invalidate_cache("get_user_in_room_with_profile", None) @@ -157,6 +170,9 @@ class SQLBaseStore(metaclass=ABCMeta): self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) self._attempt_to_invalidate_cache("get_room_type", (room_id,)) self._attempt_to_invalidate_cache("get_room_encryption", (room_id,)) + self._attempt_to_invalidate_cache( + "get_sliding_sync_rooms_for_user_from_membership_snapshots", None + ) def _attempt_to_invalidate_cache( self, cache_name: str, key: Optional[Collection[Any]] diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index f473294070..d170bbddaa 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py
@@ -40,20 +40,15 @@ from typing import ( import attr -from synapse._pydantic_compat import HAS_PYDANTIC_V2 +from synapse._pydantic_compat import BaseModel from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.engines import PostgresEngine from synapse.storage.types import Connection, Cursor -from synapse.types import JsonDict +from synapse.types import JsonDict, StrCollection from synapse.util import Clock, json_encoder from . import engines -if TYPE_CHECKING or HAS_PYDANTIC_V2: - from pydantic.v1 import BaseModel -else: - from pydantic import BaseModel - if TYPE_CHECKING: from synapse.server import HomeServer from synapse.storage.database import ( @@ -487,6 +482,31 @@ class BackgroundUpdater: return not update_exists + async def have_completed_background_updates( + self, update_names: StrCollection + ) -> bool: + """Return the name of background updates that have not yet been + completed""" + if self._all_done: + return True + + # We now check if we have completed all pending background updates. We + # do this as once this returns True then it will set `self._all_done` + # and we can skip checking the database in future. + if await self.has_completed_background_updates(): + return True + + rows = await self.db_pool.simple_select_many_batch( + table="background_updates", + column="update_name", + iterable=update_names, + retcols=("update_name",), + desc="get_uncompleted_background_updates", + ) + + # If we find any rows then we've not completed the update. 
+ return not bool(rows) + async def do_next_background_update(self, sleep: bool = True) -> bool: """Does some amount of work on the next queued background update @@ -719,9 +739,9 @@ class BackgroundUpdater: c.execute(sql) async def updater(progress: JsonDict, batch_size: int) -> int: - assert isinstance( - self.db_pool.engine, engines.PostgresEngine - ), "validate constraint background update registered for non-Postres database" + assert isinstance(self.db_pool.engine, engines.PostgresEngine), ( + "validate constraint background update registered for non-Postres database" + ) logger.info("Validating constraint %s to %s", constraint_name, table) await self.db_pool.runWithConnection(runner) @@ -769,7 +789,7 @@ class BackgroundUpdater: # we may already have a half-built index. Let's just drop it # before trying to create it again. - sql = "DROP INDEX IF EXISTS %s" % (index_name,) + sql = "DROP INDEX CONCURRENTLY IF EXISTS %s" % (index_name,) logger.debug("[SQL] %s", sql) c.execute(sql) @@ -794,7 +814,7 @@ class BackgroundUpdater: if replaces_index is not None: # We drop the old index as the new index has now been created. - sql = f"DROP INDEX IF EXISTS {replaces_index}" + sql = f"DROP INDEX CONCURRENTLY IF EXISTS {replaces_index}" logger.debug("[SQL] %s", sql) c.execute(sql) finally: @@ -880,9 +900,9 @@ class BackgroundUpdater: on the table. Used to iterate over the table. """ - assert isinstance( - self.db_pool.engine, engines.PostgresEngine - ), "validate constraint background update registered for non-Postres database" + assert isinstance(self.db_pool.engine, engines.PostgresEngine), ( + "validate constraint background update registered for non-Postres database" + ) async def updater(progress: JsonDict, batch_size: int) -> int: return await self.validate_constraint_and_delete_in_background( diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index d0e015bf19..f5131fe291 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py
@@ -332,6 +332,7 @@ class EventsPersistenceStorageController: # store for now. self.main_store = stores.main self.state_store = stores.state + self._state_deletion_store = stores.state_deletion assert stores.persist_events self.persist_events_store = stores.persist_events @@ -416,7 +417,7 @@ class EventsPersistenceStorageController: set_tag(SynapseTags.FUNC_ARG_PREFIX + "backfilled", str(backfilled)) async def enqueue( - item: Tuple[str, List[Tuple[EventBase, EventContext]]] + item: Tuple[str, List[Tuple[EventBase, EventContext]]], ) -> Dict[str, str]: room_id, evs_ctxs = item return await self._event_persist_queue.add_to_queue( @@ -502,8 +503,15 @@ class EventsPersistenceStorageController: """ state = await self._calculate_current_state(room_id) delta = await self._calculate_state_delta(room_id, state) + sliding_sync_table_changes = ( + await self.persist_events_store._calculate_sliding_sync_table_changes( + room_id, [], delta + ) + ) - await self.persist_events_store.update_current_state(room_id, delta) + await self.persist_events_store.update_current_state( + room_id, delta, sliding_sync_table_changes + ) async def _calculate_current_state(self, room_id: str) -> StateMap[str]: """Calculate the current state of a room, based on the forward extremities @@ -542,7 +550,9 @@ class EventsPersistenceStorageController: room_version, state_maps_by_state_group, event_map=None, - state_res_store=StateResolutionStore(self.main_store), + state_res_store=StateResolutionStore( + self.main_store, self._state_deletion_store + ), ) return await res.get_state(self._state_controller, StateFilter.all()) @@ -628,15 +638,20 @@ class EventsPersistenceStorageController: room_id, [e for e, _ in chunk] ) - await self.persist_events_store._persist_events_and_state_updates( - room_id, - chunk, - state_delta_for_room=state_delta_for_room, - new_forward_extremities=new_forward_extremities, - use_negative_stream_ordering=backfilled, - inhibit_local_membership_updates=backfilled, - 
new_event_links=new_event_links, - ) + # Stop the state groups from being deleted while we're persisting + # them. + async with self._state_deletion_store.persisting_state_group_references( + events_and_contexts + ): + await self.persist_events_store._persist_events_and_state_updates( + room_id, + chunk, + state_delta_for_room=state_delta_for_room, + new_forward_extremities=new_forward_extremities, + use_negative_stream_ordering=backfilled, + inhibit_local_membership_updates=backfilled, + new_event_links=new_event_links, + ) return replaced_events @@ -785,9 +800,9 @@ class EventsPersistenceStorageController: ) # Remove any events which are prev_events of any existing events. - existing_prevs: Collection[str] = ( - await self.persist_events_store._get_events_which_are_prevs(result) - ) + existing_prevs: Collection[ + str + ] = await self.persist_events_store._get_events_which_are_prevs(result) result.difference_update(existing_prevs) # Finally handle the case where the new events have soft-failed prev @@ -855,8 +870,7 @@ class EventsPersistenceStorageController: # This should only happen for outlier events. if not ev.internal_metadata.is_outlier(): raise Exception( - "Context for new event %s has no state " - "group" % (ev.event_id,) + "Context for new event %s has no state group" % (ev.event_id,) ) continue if ctx.state_group_deltas: @@ -958,7 +972,9 @@ class EventsPersistenceStorageController: room_version, state_groups, events_map, - state_res_store=StateResolutionStore(self.main_store), + state_res_store=StateResolutionStore( + self.main_store, self._state_deletion_store + ), ) state_resolutions_during_persistence.inc() diff --git a/synapse/storage/controllers/purge_events.py b/synapse/storage/controllers/purge_events.py
index e794b370c2..c2d4bf8290 100644 --- a/synapse/storage/controllers/purge_events.py +++ b/synapse/storage/controllers/purge_events.py
@@ -21,10 +21,19 @@ import itertools import logging -from typing import TYPE_CHECKING, Set +from typing import ( + TYPE_CHECKING, + Collection, + Mapping, + Optional, + Set, +) from synapse.logging.context import nested_logging_context +from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.storage.database import LoggingTransaction from synapse.storage.databases import Databases +from synapse.types.storage import _BackgroundUpdates if TYPE_CHECKING: from synapse.server import HomeServer @@ -38,12 +47,22 @@ class PurgeEventsStorageController: def __init__(self, hs: "HomeServer", stores: Databases): self.stores = stores + if hs.config.worker.run_background_tasks: + self._delete_state_loop_call = hs.get_clock().looping_call( + self._delete_state_groups_loop, 60 * 1000 + ) + + self.stores.state.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.MARK_UNREFERENCED_STATE_GROUPS_FOR_DELETION_BG_UPDATE, + self._background_delete_unrefereneced_state_groups, + ) + async def purge_room(self, room_id: str) -> None: """Deletes all record of a room""" with nested_logging_context(room_id): - state_groups_to_delete = await self.stores.main.purge_room(room_id) - await self.stores.state.purge_room_state(room_id, state_groups_to_delete) + await self.stores.main.purge_room(room_id) + await self.stores.state.purge_room_state(room_id) async def purge_history( self, room_id: str, token: str, delete_local_events: bool @@ -68,11 +87,16 @@ class PurgeEventsStorageController: logger.info("[purge] finding state groups that can be deleted") sg_to_delete = await self._find_unreferenced_groups(state_groups) - await self.stores.state.purge_unreferenced_state_groups( - room_id, sg_to_delete + # Mark these state groups as pending deletion, they will actually + # get deleted automatically later. 
+ await self.stores.state_deletion.mark_state_groups_as_pending_deletion( + sg_to_delete ) - async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]: + async def _find_unreferenced_groups( + self, + state_groups: Collection[int], + ) -> Set[int]: """Used when purging history to figure out which state groups can be deleted. @@ -118,6 +142,307 @@ class PurgeEventsStorageController: next_to_search |= prevs state_groups_seen |= prevs + # We also check to see if anything referencing the state groups are + # also unreferenced. This helps ensure that we delete unreferenced + # state groups, if we don't then we will de-delta them when we + # delete the other state groups leading to increased DB usage. + next_edges = await self.stores.state.get_next_state_groups(current_search) + nexts = set(next_edges.keys()) + nexts -= state_groups_seen + next_to_search |= nexts + state_groups_seen |= nexts + to_delete = state_groups_seen - referenced_groups return to_delete + + @wrap_as_background_process("_delete_state_groups_loop") + async def _delete_state_groups_loop(self) -> None: + """Background task that deletes any state groups that may be pending + deletion.""" + + while True: + next_to_delete = await self.stores.state_deletion.get_next_state_group_collection_to_delete() + if next_to_delete is None: + break + + (room_id, groups_to_sequences) = next_to_delete + made_progress = await self._delete_state_groups( + room_id, groups_to_sequences + ) + + # If no progress was made in deleting the state groups, then we + # break to allow a pause before trying again next time we get + # called. + if not made_progress: + break + + async def _delete_state_groups( + self, room_id: str, groups_to_sequences: Mapping[int, int] + ) -> bool: + """Tries to delete the given state groups. + + Returns: + Whether we made progress in deleting the state groups (or marking + them as referenced). + """ + + # We double check if any of the state groups have become referenced. 
+ # This shouldn't happen, as any usages should cause the state group to + # be removed as pending deletion. + referenced_state_groups = await self.stores.main.get_referenced_state_groups( + groups_to_sequences + ) + + if referenced_state_groups: + # We mark any state groups that have become referenced as being + # used. + await self.stores.state_deletion.mark_state_groups_as_used( + referenced_state_groups + ) + + # Update list of state groups to remove referenced ones + groups_to_sequences = { + state_group: sequence_number + for state_group, sequence_number in groups_to_sequences.items() + if state_group not in referenced_state_groups + } + + if not groups_to_sequences: + # We made progress here as long as we marked some state groups as + # now referenced. + return len(referenced_state_groups) > 0 + + return await self.stores.state.purge_unreferenced_state_groups( + room_id, + groups_to_sequences, + ) + + async def _background_delete_unrefereneced_state_groups( + self, progress: dict, batch_size: int + ) -> int: + """This background update will slowly delete any unreferenced state groups""" + + last_checked_state_group = progress.get("last_checked_state_group") + + if last_checked_state_group is None: + # This is the first run. + last_checked_state_group = ( + await self.stores.state.db_pool.simple_select_one_onecol( + table="state_groups", + keyvalues={}, + retcol="MAX(id)", + allow_none=True, + desc="get_max_state_group", + ) + ) + if last_checked_state_group is None: + # There are no state groups so the background process is finished. + await self.stores.state.db_pool.updates._end_background_update( + _BackgroundUpdates.MARK_UNREFERENCED_STATE_GROUPS_FOR_DELETION_BG_UPDATE + ) + return batch_size + last_checked_state_group += 1 + + ( + last_checked_state_group, + final_batch, + ) = await self._delete_unreferenced_state_groups_batch( + last_checked_state_group, + batch_size, + ) + + if not final_batch: + # There are more state groups to check. 
+ progress = { + "last_checked_state_group": last_checked_state_group, + } + await self.stores.state.db_pool.updates._background_update_progress( + _BackgroundUpdates.MARK_UNREFERENCED_STATE_GROUPS_FOR_DELETION_BG_UPDATE, + progress, + ) + else: + # This background process is finished. + await self.stores.state.db_pool.updates._end_background_update( + _BackgroundUpdates.MARK_UNREFERENCED_STATE_GROUPS_FOR_DELETION_BG_UPDATE + ) + + return batch_size + + async def _delete_unreferenced_state_groups_batch( + self, + last_checked_state_group: int, + batch_size: int, + ) -> tuple[int, bool]: + """Looks for unreferenced state groups starting from the last state group + checked and marks them for deletion. + + Args: + last_checked_state_group: The last state group that was checked. + batch_size: How many state groups to process in this iteration. + + Returns: + (last_checked_state_group, final_batch) + """ + + # Find all state groups that can be deleted if any of the original set are deleted. + ( + to_delete, + last_checked_state_group, + final_batch, + ) = await self._find_unreferenced_groups_for_background_deletion( + last_checked_state_group, batch_size + ) + + if len(to_delete) == 0: + return last_checked_state_group, final_batch + + await self.stores.state_deletion.mark_state_groups_as_pending_deletion( + to_delete + ) + + return last_checked_state_group, final_batch + + async def _find_unreferenced_groups_for_background_deletion( + self, + last_checked_state_group: int, + batch_size: int, + ) -> tuple[Set[int], int, bool]: + """Used when deleting unreferenced state groups in the background to figure out + which state groups can be deleted. + To avoid increased DB usage due to de-deltaing state groups, this returns only + state groups which are free standing (ie. no shared edges with referenced groups) or + state groups which do not share edges which result in a future referenced group. 
+ + The following scenarios outline the possibilities based on state group data in + the DB. + + ie. Free standing -> state groups 1-N would be returned: + SG_1 + | + ... + | + SG_N + + ie. Previous reference -> state groups 2-N would be returned: + SG_1 <- referenced by event + | + SG_2 + | + ... + | + SG_N + + ie. Future reference -> none of the following state groups would be returned: + SG_1 + | + SG_2 + | + ... + | + SG_N <- referenced by event + + Args: + last_checked_state_group: The last state group that was checked. + batch_size: How many state groups to process in this iteration. + + Returns: + (to_delete, last_checked_state_group, final_batch) + """ + + # If a state group's next edge is not pending deletion then we don't delete the state group. + # If there is no next edge or the next edges are all marked for deletion, then delete + # the state group. + # This holds since we walk backwards from the latest state groups, ensuring that + # we've already checked newer state groups for event references along the way. + def get_next_state_groups_marked_for_deletion_txn( + txn: LoggingTransaction, + ) -> tuple[dict[int, bool], dict[int, int]]: + state_group_sql = """ + SELECT s.id, e.state_group, d.state_group + FROM ( + SELECT id FROM state_groups + WHERE id < ? ORDER BY id DESC LIMIT ? + ) as s + LEFT JOIN state_group_edges AS e ON (s.id = e.prev_state_group) + LEFT JOIN state_groups_pending_deletion AS d ON (e.state_group = d.state_group) + """ + txn.execute(state_group_sql, (last_checked_state_group, batch_size)) + + # Mapping from state group to whether we should delete it. + state_groups_to_deletion: dict[int, bool] = {} + + # Mapping from state group to prev state group. 
+ state_groups_to_prev: dict[int, int] = {} + + for row in txn: + state_group = row[0] + next_edge = row[1] + pending_deletion = row[2] + + if next_edge is not None: + state_groups_to_prev[next_edge] = state_group + + if next_edge is not None and not pending_deletion: + # We have found an edge not marked for deletion. + # Check previous results to see if this group is part of a chain + # within this batch that qualifies for deletion. + # ie. batch contains: + # SG_1 -> SG_2 -> SG_3 + # If SG_3 is a candidate for deletion, then SG_2 & SG_1 should also + # be, even though they have edges which may not be marked for + # deletion. + # This relies on SQL results being sorted in DESC order to work. + next_is_deletion_candidate = state_groups_to_deletion.get(next_edge) + if ( + next_is_deletion_candidate is None + or not next_is_deletion_candidate + ): + state_groups_to_deletion[state_group] = False + else: + state_groups_to_deletion.setdefault(state_group, True) + else: + # This state group may be a candidate for deletion + state_groups_to_deletion.setdefault(state_group, True) + + return state_groups_to_deletion, state_groups_to_prev + + ( + state_groups_to_deletion, + state_group_edges, + ) = await self.stores.state.db_pool.runInteraction( + "get_next_state_groups_marked_for_deletion", + get_next_state_groups_marked_for_deletion_txn, + ) + deletion_candidates = { + state_group + for state_group, deletion in state_groups_to_deletion.items() + if deletion + } + + final_batch = False + state_groups = state_groups_to_deletion.keys() + if len(state_groups) < batch_size: + final_batch = True + else: + last_checked_state_group = min(state_groups) + + if len(state_groups) == 0: + return set(), last_checked_state_group, final_batch + + # Determine if any of the remaining state groups are directly referenced. 
+ referenced = await self.stores.main.get_referenced_state_groups( + deletion_candidates + ) + + # Remove state groups from deletion_candidates which are directly referenced or share a + # future edge with a referenced state group within this batch. + def filter_reference_chains(group: Optional[int]) -> None: + while group is not None: + deletion_candidates.discard(group) + group = state_group_edges.get(group) + + for referenced_group in referenced: + filter_reference_chains(referenced_group) + + return deletion_candidates, last_checked_state_group, final_batch diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index b50eb8868e..f28f5d7e03 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py
@@ -234,8 +234,11 @@ class StateStorageController: RuntimeError if we don't have a state group for one or more of the events (ie they are outliers or unknown) """ + if state_filter is None: + state_filter = StateFilter.all() + await_full_state = True - if state_filter and not state_filter.must_await_full_state(self._is_mine_id): + if not state_filter.must_await_full_state(self._is_mine_id): await_full_state = False event_to_groups = await self.get_state_group_for_events( @@ -244,7 +247,7 @@ class StateStorageController: groups = set(event_to_groups.values()) group_to_state = await self.stores.state._get_state_for_groups( - groups, state_filter or StateFilter.all() + groups, state_filter ) state_event_map = await self.stores.main.get_events( @@ -292,10 +295,11 @@ class StateStorageController: RuntimeError if we don't have a state group for one or more of the events (ie they are outliers or unknown) """ - if ( - await_full_state - and state_filter - and not state_filter.must_await_full_state(self._is_mine_id) + if state_filter is None: + state_filter = StateFilter.all() + + if await_full_state and not state_filter.must_await_full_state( + self._is_mine_id ): # Full state is not required if the state filter is restrictive enough. 
await_full_state = False @@ -306,7 +310,7 @@ class StateStorageController: groups = set(event_to_groups.values()) group_to_state = await self.stores.state._get_state_for_groups( - groups, state_filter or StateFilter.all() + groups, state_filter ) event_to_state = { @@ -335,9 +339,10 @@ class StateStorageController: RuntimeError if we don't have a state group for the event (ie it is an outlier or is unknown) """ - state_map = await self.get_state_for_events( - [event_id], state_filter or StateFilter.all() - ) + if state_filter is None: + state_filter = StateFilter.all() + + state_map = await self.get_state_for_events([event_id], state_filter) return state_map[event_id] @trace @@ -365,9 +370,12 @@ class StateStorageController: RuntimeError if we don't have a state group for the event (ie it is an outlier or is unknown) """ + if state_filter is None: + state_filter = StateFilter.all() + state_map = await self.get_state_ids_for_events( [event_id], - state_filter or StateFilter.all(), + state_filter, await_full_state=await_full_state, ) return state_map[event_id] @@ -388,9 +396,12 @@ class StateStorageController: at the event and `state_filter` is not satisfied by partial state. Defaults to `True`. """ + if state_filter is None: + state_filter = StateFilter.all() + state_ids = await self.get_state_ids_for_event( event_id, - state_filter=state_filter or StateFilter.all(), + state_filter=state_filter, await_full_state=await_full_state, ) @@ -426,6 +437,9 @@ class StateStorageController: at the last event in the room before `stream_position` and `state_filter` is not satisfied by partial state. Defaults to `True`. """ + if state_filter is None: + state_filter = StateFilter.all() + # FIXME: This gets the state at the latest event before the stream ordering, # which might not be the same as the "current state" of the room at the time # of the stream token if there were multiple forward extremities at the time. 
@@ -442,7 +456,7 @@ class StateStorageController: if last_event_id: state = await self.get_state_after_event( last_event_id, - state_filter=state_filter or StateFilter.all(), + state_filter=state_filter, await_full_state=await_full_state, ) @@ -500,9 +514,10 @@ class StateStorageController: Returns: Dict of state group to state map. """ - return await self.stores.state._get_state_for_groups( - groups, state_filter or StateFilter.all() - ) + if state_filter is None: + state_filter = StateFilter.all() + + return await self.stores.state._get_state_for_groups(groups, state_filter) @trace @tag_args @@ -583,12 +598,13 @@ class StateStorageController: Returns: The current state of the room. """ - if await_full_state and ( - not state_filter or state_filter.must_await_full_state(self._is_mine_id) - ): + if state_filter is None: + state_filter = StateFilter.all() + + if await_full_state and state_filter.must_await_full_state(self._is_mine_id): await self._partial_state_room_tracker.await_full_state(room_id) - if state_filter and not state_filter.is_full(): + if state_filter is not None and not state_filter.is_full(): return await self.stores.main.get_partial_filtered_current_state_ids( room_id, state_filter ) diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 569f618193..a4941e58f6 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py
@@ -35,6 +35,8 @@ from typing import ( Iterable, Iterator, List, + Literal, + Mapping, Optional, Sequence, Tuple, @@ -46,7 +48,7 @@ from typing import ( import attr from prometheus_client import Counter, Histogram -from typing_extensions import Concatenate, Literal, ParamSpec +from typing_extensions import Concatenate, ParamSpec from twisted.enterprise import adbapi from twisted.internet.interfaces import IReactorCore @@ -64,6 +66,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor, SQLQueryParameters +from synapse.types import StrCollection from synapse.util.async_helpers import delay_cancellation from synapse.util.iterutils import batch_iter @@ -1095,6 +1098,48 @@ class DatabasePool: txn.execute(sql, vals) + @staticmethod + def simple_insert_returning_txn( + txn: LoggingTransaction, + table: str, + values: Dict[str, Any], + returning: StrCollection, + ) -> Tuple[Any, ...]: + """Executes a `INSERT INTO... RETURNING...` statement (or equivalent for + SQLite versions that don't support it). + """ + + if txn.database_engine.supports_returning: + sql = "INSERT INTO %s (%s) VALUES(%s) RETURNING %s" % ( + table, + ", ".join(k for k in values.keys()), + ", ".join("?" 
for _ in values.keys()), + ", ".join(k for k in returning), + ) + + txn.execute(sql, list(values.values())) + row = txn.fetchone() + assert row is not None + return row + else: + # For old versions of SQLite we do a standard insert and then can + # use `last_insert_rowid` to get at the row we just inserted + DatabasePool.simple_insert_txn( + txn, + table=table, + values=values, + ) + txn.execute("SELECT last_insert_rowid()") + row = txn.fetchone() + assert row is not None + (rowid,) = row + + row = DatabasePool.simple_select_one_txn( + txn, table=table, keyvalues={"rowid": rowid}, retcols=returning + ) + assert row is not None + return row + async def simple_insert_many( self, table: str, @@ -1254,9 +1299,9 @@ class DatabasePool: self, txn: LoggingTransaction, table: str, - keyvalues: Dict[str, Any], - values: Dict[str, Any], - insertion_values: Optional[Dict[str, Any]] = None, + keyvalues: Mapping[str, Any], + values: Mapping[str, Any], + insertion_values: Optional[Mapping[str, Any]] = None, where_clause: Optional[str] = None, ) -> bool: """ @@ -1299,9 +1344,9 @@ class DatabasePool: self, txn: LoggingTransaction, table: str, - keyvalues: Dict[str, Any], - values: Dict[str, Any], - insertion_values: Optional[Dict[str, Any]] = None, + keyvalues: Mapping[str, Any], + values: Mapping[str, Any], + insertion_values: Optional[Mapping[str, Any]] = None, where_clause: Optional[str] = None, lock: bool = True, ) -> bool: @@ -1322,7 +1367,7 @@ class DatabasePool: if lock: # We need to lock the table :( - self.engine.lock_table(txn, table) + txn.database_engine.lock_table(txn, table) def _getwhere(key: str) -> str: # If the value we're passing in is None (aka NULL), we need to use @@ -1376,13 +1421,13 @@ class DatabasePool: # successfully inserted return True + @staticmethod def simple_upsert_txn_native_upsert( - self, txn: LoggingTransaction, table: str, - keyvalues: Dict[str, Any], - values: Dict[str, Any], - insertion_values: Optional[Dict[str, Any]] = None, + keyvalues: 
Mapping[str, Any], + values: Mapping[str, Any], + insertion_values: Optional[Mapping[str, Any]] = None, where_clause: Optional[str] = None, ) -> bool: """ @@ -1535,8 +1580,8 @@ class DatabasePool: self.simple_upsert_txn_emulated(txn, table, _keys, _vals, lock=False) + @staticmethod def simple_upsert_many_txn_native_upsert( - self, txn: LoggingTransaction, table: str, key_names: Collection[str], @@ -1966,8 +2011,8 @@ class DatabasePool: def simple_update_txn( txn: LoggingTransaction, table: str, - keyvalues: Dict[str, Any], - updatevalues: Dict[str, Any], + keyvalues: Mapping[str, Any], + updatevalues: Mapping[str, Any], ) -> int: """ Update rows in the given database table. @@ -2115,10 +2160,26 @@ class DatabasePool: if rowcount > 1: raise StoreError(500, "More than one row matched (%s)" % (table,)) - # Ideally we could use the overload decorator here to specify that the - # return type is only optional if allow_none is True, but this does not work - # when you call a static method from an instance. - # See https://github.com/python/mypy/issues/7781 + @overload + @staticmethod + def simple_select_one_txn( + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + retcols: Collection[str], + allow_none: Literal[False] = False, + ) -> Tuple[Any, ...]: ... + + @overload + @staticmethod + def simple_select_one_txn( + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + retcols: Collection[str], + allow_none: Literal[True] = True, + ) -> Optional[Tuple[Any, ...]]: ... + @staticmethod def simple_select_one_txn( txn: LoggingTransaction, diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index dd9fc01fb0..81886ff765 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py
@@ -26,6 +26,7 @@ from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_conn from synapse.storage.databases.main.events import PersistEventsStore from synapse.storage.databases.state import StateGroupDataStore +from synapse.storage.databases.state.deletion import StateDeletionDataStore from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database @@ -49,12 +50,14 @@ class Databases(Generic[DataStoreT]): main state persist_events + state_deletion """ databases: List[DatabasePool] main: "DataStore" # FIXME: https://github.com/matrix-org/synapse/issues/11165: actually an instance of `main_store_class` state: StateGroupDataStore persist_events: Optional[PersistEventsStore] + state_deletion: StateDeletionDataStore def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"): # Note we pass in the main store class here as workers use a different main @@ -63,6 +66,7 @@ class Databases(Generic[DataStoreT]): self.databases = [] main: Optional[DataStoreT] = None state: Optional[StateGroupDataStore] = None + state_deletion: Optional[StateDeletionDataStore] = None persist_events: Optional[PersistEventsStore] = None for database_config in hs.config.database.databases: @@ -114,7 +118,8 @@ class Databases(Generic[DataStoreT]): if state: raise Exception("'state' data store already configured") - state = StateGroupDataStore(database, db_conn, hs) + state_deletion = StateDeletionDataStore(database, db_conn, hs) + state = StateGroupDataStore(database, db_conn, hs, state_deletion) db_conn.commit() @@ -135,7 +140,7 @@ class Databases(Generic[DataStoreT]): if not main: raise Exception("No 'main' database configured") - if not state: + if not state or not state_deletion: raise Exception("No 'state' database configured") # We use local variables here to ensure that the databases do not have @@ -143,3 +148,4 @@ class Databases(Generic[DataStoreT]): self.main = main # type: 
ignore[assignment] self.state = state self.persist_events = persist_events + self.state_deletion = state_deletion diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 586e84f2a4..86431f6e40 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py
@@ -3,7 +3,7 @@ # # Copyright 2019-2021 The Matrix.org Foundation C.I.C. # Copyright 2014-2016 OpenMarket Ltd -# Copyright (C) 2023 New Vector, Ltd +# Copyright (C) 2023-2024 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as @@ -33,6 +33,7 @@ from synapse.storage.database import ( LoggingDatabaseConnection, LoggingTransaction, ) +from synapse.storage.databases.main.sliding_sync import SlidingSyncStore from synapse.storage.databases.main.stats import UserSortOrder from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.types import Cursor @@ -43,6 +44,7 @@ from .appservice import ApplicationServiceStore, ApplicationServiceTransactionSt from .cache import CacheInvalidationWorkerStore from .censor_events import CensorEventsStore from .client_ips import ClientIpWorkerStore +from .delayed_events import DelayedEventsStore from .deviceinbox import DeviceInboxStore from .devices import DeviceStore from .directory import DirectoryStore @@ -156,6 +158,8 @@ class DataStore( LockStore, SessionStore, TaskSchedulerWorkerStore, + SlidingSyncStore, + DelayedEventsStore, ): def __init__( self, diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 966393869b..715815cc09 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py
@@ -34,6 +34,7 @@ from typing import ( ) from synapse.api.constants import AccountDataTypes +from synapse.api.errors import Codes, SynapseError from synapse.replication.tcp.streams import AccountDataStream from synapse.storage._base import db_to_json from synapse.storage.database import ( @@ -43,6 +44,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.push_rule import PushRulesWorkerStore +from synapse.storage.invite_rule import InviteRulesConfig from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import JsonDict, JsonMapping from synapse.util import json_encoder @@ -102,6 +104,8 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) self._delete_account_data_for_deactivated_users, ) + self._msc4155_enabled = hs.config.experimental.msc4155_enabled + def get_max_account_data_stream_id(self) -> int: """Get the current max stream ID for account data stream @@ -177,7 +181,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) def get_room_account_data_for_user_txn( txn: LoggingTransaction, - ) -> Dict[str, Dict[str, JsonDict]]: + ) -> Dict[str, Dict[str, JsonMapping]]: # The 'content != '{}' condition below prevents us from using # `simple_select_list_txn` here, as it doesn't support conditions # other than 'equals'. 
@@ -194,7 +198,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) txn.execute(sql, (user_id,)) - by_room: Dict[str, Dict[str, JsonDict]] = {} + by_room: Dict[str, Dict[str, JsonMapping]] = {} for room_id, account_data_type, content in txn: room_data = by_room.setdefault(room_id, {}) @@ -394,7 +398,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) async def get_updated_global_account_data_for_user( self, user_id: str, stream_id: int - ) -> Mapping[str, JsonMapping]: + ) -> Dict[str, JsonMapping]: """Get all the global account_data that's changed for a user. Args: @@ -407,7 +411,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) def get_updated_global_account_data_for_user( txn: LoggingTransaction, - ) -> Dict[str, JsonDict]: + ) -> Dict[str, JsonMapping]: sql = """ SELECT account_data_type, content FROM account_data WHERE user_id = ? AND stream_id > ? @@ -429,7 +433,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) async def get_updated_room_account_data_for_user( self, user_id: str, stream_id: int - ) -> Dict[str, Dict[str, JsonDict]]: + ) -> Dict[str, Dict[str, JsonMapping]]: """Get all the room account_data that's changed for a user. Args: @@ -442,14 +446,14 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) def get_updated_room_account_data_for_user_txn( txn: LoggingTransaction, - ) -> Dict[str, Dict[str, JsonDict]]: + ) -> Dict[str, Dict[str, JsonMapping]]: sql = """ SELECT room_id, account_data_type, content FROM room_account_data WHERE user_id = ? AND stream_id > ? 
""" txn.execute(sql, (user_id, stream_id)) - account_data_by_room: Dict[str, Dict[str, JsonDict]] = {} + account_data_by_room: Dict[str, Dict[str, JsonMapping]] = {} for row in txn: room_account_data = account_data_by_room.setdefault(row[0], {}) room_account_data[row[1]] = db_to_json(row[2]) @@ -467,6 +471,56 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) get_updated_room_account_data_for_user_txn, ) + async def get_updated_room_account_data_for_user_for_room( + self, + # Since there are multiple arguments with the same type, force keyword arguments + # so people don't accidentally swap the order + *, + user_id: str, + room_id: str, + from_stream_id: int, + to_stream_id: int, + ) -> Dict[str, JsonMapping]: + """Get the room account_data that's changed for a user in a room. + + (> `from_stream_id` and <= `to_stream_id`) + + Args: + user_id: The user to get the account_data for. + room_id: The room to check + from_stream_id: The point in the stream to fetch from + to_stream_id: The point in the stream to fetch to + + Returns: + A dict of the room account data. + """ + + def get_updated_room_account_data_for_user_for_room_txn( + txn: LoggingTransaction, + ) -> Dict[str, JsonMapping]: + sql = """ + SELECT account_data_type, content FROM room_account_data + WHERE user_id = ? AND room_id = ? AND stream_id > ? AND stream_id <= ? 
+ """ + txn.execute(sql, (user_id, room_id, from_stream_id, to_stream_id)) + + room_account_data: Dict[str, JsonMapping] = {} + for row in txn: + room_account_data[row[0]] = db_to_json(row[1]) + + return room_account_data + + changed = self._account_data_stream_cache.has_entity_changed( + user_id, int(from_stream_id) + ) + if not changed: + return {} + + return await self.db_pool.runInteraction( + "get_updated_room_account_data_for_user_for_room", + get_updated_room_account_data_for_user_for_room_txn, + ) + @cached(max_entries=5000, iterable=True) async def ignored_by(self, user_id: str) -> FrozenSet[str]: """ @@ -507,6 +561,23 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) ) ) + async def get_invite_config_for_user(self, user_id: str) -> InviteRulesConfig: + """ + Get the invite configuration for the current user. + + Args: + user_id: + """ + + if not self._msc4155_enabled: + # This equates to allowing all invites, as if the setting was off. + return InviteRulesConfig(None) + + data = await self.get_global_account_data_by_type_for_user( + user_id, AccountDataTypes.MSC4155_INVITE_PERMISSION_CONFIG + ) + return InviteRulesConfig(data) + def process_replication_rows( self, stream_name: str, @@ -710,6 +781,9 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) else: currently_ignored_users = set() + if user_id in currently_ignored_users: + raise SynapseError(400, "You cannot ignore yourself", Codes.INVALID_PARAM) + # If the data has not changed, nothing to do. if previously_ignored_users == currently_ignored_users: return diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 63624f3e8f..9418fb6dd7 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py
@@ -41,6 +41,7 @@ from synapse.storage.database import ( LoggingDatabaseConnection, LoggingTransaction, ) +from synapse.storage.databases.main.events import SLIDING_SYNC_RELEVANT_STATE_SET from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.util.caches.descriptors import CachedFunction @@ -218,6 +219,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore): room_id = row.keys[0] members_changed = set(row.keys[1:]) self._invalidate_state_caches(room_id, members_changed) + self._curr_state_delta_stream_cache.entity_has_changed( # type: ignore[attr-defined] + room_id, token + ) + for user_id in members_changed: + self._membership_stream_cache.entity_has_changed(user_id, token) # type: ignore[attr-defined] elif row.cache_func == PURGE_HISTORY_CACHE_NAME: if row.keys is None: raise Exception( @@ -235,6 +241,35 @@ class CacheInvalidationWorkerStore(SQLBaseStore): room_id = row.keys[0] self._invalidate_caches_for_room_events(room_id) self._invalidate_caches_for_room(room_id) + self._curr_state_delta_stream_cache.entity_has_changed( # type: ignore[attr-defined] + room_id, token + ) + # Note: This code is commented out to improve cache performance. + # While uncommenting would provide complete correctness, our + # automatic forgotten room purge logic (see + # `forgotten_room_retention_period`) means this would frequently + # clear the entire cache (effectively) and probably have a noticable + # impact on the cache hit ratio. + # + # Not updating the cache here is safe because: + # + # 1. `_membership_stream_cache` is only used to indicate the + # *absence* of changes, i.e. "nothing has changed between tokens + # X and Y and so return early and don't query the database". + # 2. 
`_membership_stream_cache` is used when we query data from + # `current_state_delta_stream` and `room_memberships` but since + # nothing new is written to the database for those tables when + # purging/deleting a room (only deleting rows), there is nothing + # changed to care about. + # + # At worst, the cache might indicate a change at token X, at which + # point, we will query the database and discover nothing is there. + # + # Ideally, we would make it so that we could clear the cache on a + # more granular level but that's a bit complex and fiddly to do with + # room membership. + # + # self._membership_stream_cache.all_entities_changed(token) # type: ignore[attr-defined] else: self._attempt_to_invalidate_cache(row.cache_func, row.keys) @@ -271,20 +306,33 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache( "get_rooms_for_user", (data.state_key,) ) + self._attempt_to_invalidate_cache( + "get_sliding_sync_rooms_for_user_from_membership_snapshots", None + ) + self._membership_stream_cache.entity_has_changed(data.state_key, token) # type: ignore[attr-defined] elif data.type == EventTypes.RoomEncryption: self._attempt_to_invalidate_cache( "get_room_encryption", (data.room_id,) ) elif data.type == EventTypes.Create: self._attempt_to_invalidate_cache("get_room_type", (data.room_id,)) + + if (data.type, data.state_key) in SLIDING_SYNC_RELEVANT_STATE_SET: + self._attempt_to_invalidate_cache( + "get_sliding_sync_rooms_for_user_from_membership_snapshots", None + ) elif row.type == EventsStreamAllStateRow.TypeId: assert isinstance(data, EventsStreamAllStateRow) # Similar to the above, but the entire caches are invalidated. This is # unfortunate for the membership caches, but should recover quickly. 
self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token) # type: ignore[attr-defined] + self._membership_stream_cache.all_entities_changed(token) # type: ignore[attr-defined] self._attempt_to_invalidate_cache("get_rooms_for_user", None) self._attempt_to_invalidate_cache("get_room_type", (data.room_id,)) self._attempt_to_invalidate_cache("get_room_encryption", (data.room_id,)) + self._attempt_to_invalidate_cache( + "get_sliding_sync_rooms_for_user_from_membership_snapshots", None + ) else: raise Exception("Unknown events stream row type %s" % (row.type,)) @@ -312,6 +360,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache( "get_unread_event_push_actions_by_room_for_user", (room_id,) ) + self._attempt_to_invalidate_cache("get_metadata_for_event", (room_id, event_id)) + + self._attempt_to_invalidate_cache("_get_max_event_pos", (room_id,)) # The `_get_membership_from_event_id` is immutable, except for the # case where we look up an event *before* persisting it. 
@@ -344,6 +395,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache( "_get_rooms_for_local_user_where_membership_is_inner", (state_key,) ) + self._attempt_to_invalidate_cache( + "get_sliding_sync_rooms_for_user_from_membership_snapshots", + (state_key,), + ) self._attempt_to_invalidate_cache( "did_forget", @@ -360,6 +415,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore): elif etype == EventTypes.RoomEncryption: self._attempt_to_invalidate_cache("get_room_encryption", (room_id,)) + if (etype, state_key) in SLIDING_SYNC_RELEVANT_STATE_SET: + self._attempt_to_invalidate_cache( + "get_sliding_sync_rooms_for_user_from_membership_snapshots", None + ) + if relates_to: self._attempt_to_invalidate_cache( "get_relations_for_event", @@ -404,6 +464,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): ) self._attempt_to_invalidate_cache("get_relations_for_event", (room_id,)) + self._attempt_to_invalidate_cache("_get_max_event_pos", (room_id,)) + self._attempt_to_invalidate_cache("_get_membership_from_event_id", None) self._attempt_to_invalidate_cache("get_applicable_edit", None) self._attempt_to_invalidate_cache("get_thread_id", None) @@ -413,6 +475,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache( "_get_rooms_for_local_user_where_membership_is_inner", None ) + self._attempt_to_invalidate_cache( + "get_sliding_sync_rooms_for_user_from_membership_snapshots", None + ) self._attempt_to_invalidate_cache("did_forget", None) self._attempt_to_invalidate_cache("get_forgotten_rooms_for_user", None) self._attempt_to_invalidate_cache("get_references_for_event", None) @@ -425,6 +490,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache("_get_state_group_for_event", None) self._attempt_to_invalidate_cache("get_event_ordering", None) + self._attempt_to_invalidate_cache("get_metadata_for_event", (room_id,)) self._attempt_to_invalidate_cache("is_partial_state_event", None) 
self._attempt_to_invalidate_cache("_get_joined_profile_from_event_id", None) @@ -450,6 +516,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache("get_account_data_for_room", None) self._attempt_to_invalidate_cache("get_account_data_for_room_and_type", None) + self._attempt_to_invalidate_cache("get_tags_for_room", None) self._attempt_to_invalidate_cache("get_aliases_for_room", (room_id,)) self._attempt_to_invalidate_cache("get_latest_event_ids_in_room", (room_id,)) self._attempt_to_invalidate_cache("_get_forward_extremeties_for_room", None) @@ -469,6 +536,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache( "get_current_hosts_in_room_ordered", (room_id,) ) + self._attempt_to_invalidate_cache( + "get_sliding_sync_rooms_for_user_from_membership_snapshots", None + ) self._attempt_to_invalidate_cache("did_forget", None) self._attempt_to_invalidate_cache("get_forgotten_rooms_for_user", None) self._attempt_to_invalidate_cache("_get_membership_from_event_id", None) @@ -476,6 +546,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache("get_room_type", (room_id,)) self._attempt_to_invalidate_cache("get_room_encryption", (room_id,)) + self._attempt_to_invalidate_cache("_get_max_event_pos", (room_id,)) + # And delete state caches. self._invalidate_state_caches_all(room_id) diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 4b66247640..69008804bd 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py
@@ -20,10 +20,19 @@ # import logging -from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Dict, + List, + Mapping, + Optional, + Tuple, + TypedDict, + Union, + cast, +) import attr -from typing_extensions import TypedDict from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore @@ -238,9 +247,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): INNER JOIN user_ips USING (user_id, access_token, ip) GROUP BY user_id, access_token, ip HAVING count(*) > 1 - """.format( - clause - ), + """.format(clause), args, ) res = cast( @@ -373,9 +380,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): LIMIT ? ) c INNER JOIN user_ips AS u USING (user_id, device_id, last_seen) - """ % { - "where_clause": where_clause - } + """ % {"where_clause": where_clause} txn.execute(sql, where_args + [batch_size]) rows = cast(List[Tuple[int, str, str, str, str]], txn.fetchall()) @@ -645,9 +650,9 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke @wrap_as_background_process("update_client_ips") async def _update_client_ips_batch(self) -> None: - assert ( - self._update_on_this_worker - ), "This worker is not designated to update client IPs" + assert self._update_on_this_worker, ( + "This worker is not designated to update client IPs" + ) # If the DB pool has already terminated, don't try updating if not self.db_pool.is_running(): @@ -666,9 +671,9 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke txn: LoggingTransaction, to_update: Mapping[Tuple[str, str, str], Tuple[str, Optional[str], int]], ) -> None: - assert ( - self._update_on_this_worker - ), "This worker is not designated to update client IPs" + assert self._update_on_this_worker, ( + "This worker is not designated to update client IPs" + ) # Keys and values for the `user_ips` upsert. 
user_ips_keys = [] diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py new file mode 100644
index 0000000000..c88682d55c --- /dev/null +++ b/synapse/storage/databases/main/delayed_events.py
@@ -0,0 +1,549 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2024 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# + +import logging +from typing import List, NewType, Optional, Tuple + +import attr + +from synapse.api.errors import NotFoundError +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import LoggingTransaction, StoreError +from synapse.storage.engines import PostgresEngine +from synapse.types import JsonDict, RoomID +from synapse.util import json_encoder, stringutils as stringutils + +logger = logging.getLogger(__name__) + + +DelayID = NewType("DelayID", str) +UserLocalpart = NewType("UserLocalpart", str) +DeviceID = NewType("DeviceID", str) +EventType = NewType("EventType", str) +StateKey = NewType("StateKey", str) + +Delay = NewType("Delay", int) +Timestamp = NewType("Timestamp", int) + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class EventDetails: + room_id: RoomID + type: EventType + state_key: Optional[StateKey] + origin_server_ts: Optional[Timestamp] + content: JsonDict + device_id: Optional[DeviceID] + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class DelayedEventDetails(EventDetails): + delay_id: DelayID + user_localpart: UserLocalpart + + +class DelayedEventsStore(SQLBaseStore): + async def get_delayed_events_stream_pos(self) -> int: + """ + Gets the stream position of the background process to watch for state events + that target the same piece of state as any pending delayed events. 
+ """ + return await self.db_pool.simple_select_one_onecol( + table="delayed_events_stream_pos", + keyvalues={}, + retcol="stream_id", + desc="get_delayed_events_stream_pos", + ) + + async def update_delayed_events_stream_pos(self, stream_id: Optional[int]) -> None: + """ + Updates the stream position of the background process to watch for state events + that target the same piece of state as any pending delayed events. + + Must only be used by the worker running the background process. + """ + await self.db_pool.simple_update_one( + table="delayed_events_stream_pos", + keyvalues={}, + updatevalues={"stream_id": stream_id}, + desc="update_delayed_events_stream_pos", + ) + + async def add_delayed_event( + self, + *, + user_localpart: str, + device_id: Optional[str], + creation_ts: Timestamp, + room_id: str, + event_type: str, + state_key: Optional[str], + origin_server_ts: Optional[int], + content: JsonDict, + delay: int, + ) -> Tuple[DelayID, Timestamp]: + """ + Inserts a new delayed event in the DB. + + Returns: The generated ID assigned to the added delayed event, + and the send time of the next delayed event to be sent, + which is either the event just added or one added earlier. 
+ """ + delay_id = _generate_delay_id() + send_ts = Timestamp(creation_ts + delay) + + def add_delayed_event_txn(txn: LoggingTransaction) -> Timestamp: + self.db_pool.simple_insert_txn( + txn, + table="delayed_events", + values={ + "delay_id": delay_id, + "user_localpart": user_localpart, + "device_id": device_id, + "delay": delay, + "send_ts": send_ts, + "room_id": room_id, + "event_type": event_type, + "state_key": state_key, + "origin_server_ts": origin_server_ts, + "content": json_encoder.encode(content), + }, + ) + + next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) + assert next_send_ts is not None + return next_send_ts + + next_send_ts = await self.db_pool.runInteraction( + "add_delayed_event", add_delayed_event_txn + ) + + return delay_id, next_send_ts + + async def restart_delayed_event( + self, + *, + delay_id: str, + user_localpart: str, + current_ts: Timestamp, + ) -> Timestamp: + """ + Restarts the send time of the matching delayed event, + as long as it hasn't already been marked for processing. + + Args: + delay_id: The ID of the delayed event to restart. + user_localpart: The localpart of the delayed event's owner. + current_ts: The current time, which will be used to calculate the new send time. + + Returns: The send time of the next delayed event to be sent, + which is either the event just restarted, or another one + with an earlier send time than the restarted one's new send time. + + Raises: + NotFoundError: if there is no matching delayed event. + """ + + def restart_delayed_event_txn( + txn: LoggingTransaction, + ) -> Timestamp: + txn.execute( + """ + UPDATE delayed_events + SET send_ts = ? + delay + WHERE delay_id = ? AND user_localpart = ? 
+ AND NOT is_processed + """, + ( + current_ts, + delay_id, + user_localpart, + ), + ) + if txn.rowcount == 0: + raise NotFoundError("Delayed event not found") + + next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) + assert next_send_ts is not None + return next_send_ts + + return await self.db_pool.runInteraction( + "restart_delayed_event", restart_delayed_event_txn + ) + + async def get_all_delayed_events_for_user( + self, + user_localpart: str, + ) -> List[JsonDict]: + """Returns all pending delayed events owned by the given user.""" + # TODO: Support Pagination stream API ("next_batch" field) + rows = await self.db_pool.execute( + "get_all_delayed_events_for_user", + """ + SELECT + delay_id, + room_id, + event_type, + state_key, + delay, + send_ts, + content + FROM delayed_events + WHERE user_localpart = ? AND NOT is_processed + ORDER BY send_ts + """, + user_localpart, + ) + return [ + { + "delay_id": DelayID(row[0]), + "room_id": str(RoomID.from_string(row[1])), + "type": EventType(row[2]), + **({"state_key": StateKey(row[3])} if row[3] is not None else {}), + "delay": Delay(row[4]), + "running_since": Timestamp(row[5] - row[4]), + "content": db_to_json(row[6]), + } + for row in rows + ] + + async def process_timeout_delayed_events( + self, current_ts: Timestamp + ) -> Tuple[ + List[DelayedEventDetails], + Optional[Timestamp], + ]: + """ + Marks for processing all delayed events that should have been sent prior to the provided time + that haven't already been marked as such. + + Returns: The details of all newly-processed delayed events, + and the send time of the next delayed event to be sent, if any. 
+ """ + + def process_timeout_delayed_events_txn( + txn: LoggingTransaction, + ) -> Tuple[ + List[DelayedEventDetails], + Optional[Timestamp], + ]: + sql_cols = ", ".join( + ( + "delay_id", + "user_localpart", + "room_id", + "event_type", + "state_key", + "origin_server_ts", + "send_ts", + "content", + "device_id", + ) + ) + sql_update = "UPDATE delayed_events SET is_processed = TRUE" + sql_where = "WHERE send_ts <= ? AND NOT is_processed" + sql_args = (current_ts,) + sql_order = "ORDER BY send_ts" + if isinstance(self.database_engine, PostgresEngine): + # Do this only in Postgres because: + # - SQLite's RETURNING emits rows in an arbitrary order + # - https://www.sqlite.org/lang_returning.html#limitations_and_caveats + # - SQLite does not support data-modifying statements in a WITH clause + # - https://www.sqlite.org/lang_with.html + # - https://www.postgresql.org/docs/current/queries-with.html#QUERIES-WITH-MODIFYING + txn.execute( + f""" + WITH events_to_send AS ( + {sql_update} {sql_where} RETURNING * + ) SELECT {sql_cols} FROM events_to_send {sql_order} + """, + sql_args, + ) + rows = txn.fetchall() + else: + txn.execute( + f"SELECT {sql_cols} FROM delayed_events {sql_where} {sql_order}", + sql_args, + ) + rows = txn.fetchall() + txn.execute(f"{sql_update} {sql_where}", sql_args) + assert txn.rowcount == len(rows) + + events = [ + DelayedEventDetails( + RoomID.from_string(row[2]), + EventType(row[3]), + StateKey(row[4]) if row[4] is not None else None, + # If no custom_origin_ts is set, use send_ts as the event's timestamp + Timestamp(row[5] if row[5] is not None else row[6]), + db_to_json(row[7]), + DeviceID(row[8]) if row[8] is not None else None, + DelayID(row[0]), + UserLocalpart(row[1]), + ) + for row in rows + ] + next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) + return events, next_send_ts + + return await self.db_pool.runInteraction( + "process_timeout_delayed_events", process_timeout_delayed_events_txn + ) + + async def 
process_target_delayed_event( + self, + *, + delay_id: str, + user_localpart: str, + ) -> Tuple[ + EventDetails, + Optional[Timestamp], + ]: + """ + Marks for processing the matching delayed event, regardless of its timeout time, + as long as it has not already been marked as such. + + Args: + delay_id: The ID of the delayed event to restart. + user_localpart: The localpart of the delayed event's owner. + + Returns: The details of the matching delayed event, + and the send time of the next delayed event to be sent, if any. + + Raises: + NotFoundError: if there is no matching delayed event. + """ + + def process_target_delayed_event_txn( + txn: LoggingTransaction, + ) -> Tuple[ + EventDetails, + Optional[Timestamp], + ]: + sql_cols = ", ".join( + ( + "room_id", + "event_type", + "state_key", + "origin_server_ts", + "content", + "device_id", + ) + ) + sql_update = "UPDATE delayed_events SET is_processed = TRUE" + sql_where = "WHERE delay_id = ? AND user_localpart = ? AND NOT is_processed" + sql_args = (delay_id, user_localpart) + txn.execute( + ( + f"{sql_update} {sql_where} RETURNING {sql_cols}" + if self.database_engine.supports_returning + else f"SELECT {sql_cols} FROM delayed_events {sql_where}" + ), + sql_args, + ) + row = txn.fetchone() + if row is None: + raise NotFoundError("Delayed event not found") + elif not self.database_engine.supports_returning: + txn.execute(f"{sql_update} {sql_where}", sql_args) + assert txn.rowcount == 1 + + event = EventDetails( + RoomID.from_string(row[0]), + EventType(row[1]), + StateKey(row[2]) if row[2] is not None else None, + Timestamp(row[3]) if row[3] is not None else None, + db_to_json(row[4]), + DeviceID(row[5]) if row[5] is not None else None, + ) + + return event, self._get_next_delayed_event_send_ts_txn(txn) + + return await self.db_pool.runInteraction( + "process_target_delayed_event", process_target_delayed_event_txn + ) + + async def cancel_delayed_event( + self, + *, + delay_id: str, + user_localpart: str, + ) -> 
Optional[Timestamp]: + """ + Cancels the matching delayed event, i.e. remove it as long as it hasn't been processed. + + Args: + delay_id: The ID of the delayed event to restart. + user_localpart: The localpart of the delayed event's owner. + + Returns: The send time of the next delayed event to be sent, if any. + + Raises: + NotFoundError: if there is no matching delayed event. + """ + + def cancel_delayed_event_txn( + txn: LoggingTransaction, + ) -> Optional[Timestamp]: + try: + self.db_pool.simple_delete_one_txn( + txn, + table="delayed_events", + keyvalues={ + "delay_id": delay_id, + "user_localpart": user_localpart, + "is_processed": False, + }, + ) + except StoreError: + if txn.rowcount == 0: + raise NotFoundError("Delayed event not found") + else: + raise + + return self._get_next_delayed_event_send_ts_txn(txn) + + return await self.db_pool.runInteraction( + "cancel_delayed_event", cancel_delayed_event_txn + ) + + async def cancel_delayed_state_events( + self, + *, + room_id: str, + event_type: str, + state_key: str, + not_from_localpart: str, + ) -> Optional[Timestamp]: + """ + Cancels all matching delayed state events, i.e. remove them as long as they haven't been processed. + + Args: + room_id: The room ID to match against. + event_type: The event type to match against. + state_key: The state key to match against. + not_from_localpart: The localpart of a user whose delayed events to not cancel. + If set to the empty string, any users' delayed events may be cancelled. + + Returns: The send time of the next delayed event to be sent, if any. + """ + + def cancel_delayed_state_events_txn( + txn: LoggingTransaction, + ) -> Optional[Timestamp]: + txn.execute( + """ + DELETE FROM delayed_events + WHERE room_id = ? AND event_type = ? AND state_key = ? + AND user_localpart <> ? 
+ AND NOT is_processed + """, + ( + room_id, + event_type, + state_key, + not_from_localpart, + ), + ) + return self._get_next_delayed_event_send_ts_txn(txn) + + return await self.db_pool.runInteraction( + "cancel_delayed_state_events", cancel_delayed_state_events_txn + ) + + async def delete_processed_delayed_event( + self, + delay_id: DelayID, + user_localpart: UserLocalpart, + ) -> None: + """ + Delete the matching delayed event, as long as it has been marked as processed. + + Throws: + StoreError: if there is no matching delayed event, or if it has not yet been processed. + """ + return await self.db_pool.simple_delete_one( + table="delayed_events", + keyvalues={ + "delay_id": delay_id, + "user_localpart": user_localpart, + "is_processed": True, + }, + desc="delete_processed_delayed_event", + ) + + async def delete_processed_delayed_state_events( + self, + *, + room_id: str, + event_type: str, + state_key: str, + ) -> None: + """ + Delete the matching delayed state events that have been marked as processed. + """ + await self.db_pool.simple_delete( + table="delayed_events", + keyvalues={ + "room_id": room_id, + "event_type": event_type, + "state_key": state_key, + "is_processed": True, + }, + desc="delete_processed_delayed_state_events", + ) + + async def unprocess_delayed_events(self) -> None: + """ + Unmark all delayed events for processing. + """ + await self.db_pool.simple_update( + table="delayed_events", + keyvalues={"is_processed": True}, + updatevalues={"is_processed": False}, + desc="unprocess_delayed_events", + ) + + async def get_next_delayed_event_send_ts(self) -> Optional[Timestamp]: + """ + Returns the send time of the next delayed event to be sent, if any. 
+ """ + return await self.db_pool.runInteraction( + "get_next_delayed_event_send_ts", + self._get_next_delayed_event_send_ts_txn, + db_autocommit=True, + ) + + def _get_next_delayed_event_send_ts_txn( + self, txn: LoggingTransaction + ) -> Optional[Timestamp]: + result = self.db_pool.simple_select_one_onecol_txn( + txn, + table="delayed_events", + keyvalues={"is_processed": False}, + retcol="MIN(send_ts)", + allow_none=True, + ) + return Timestamp(result) if result is not None else None + + +def _generate_delay_id() -> DelayID: + """Generates an opaque string, for use as a delay ID""" + + # We use the following format for delay IDs: + # syd_<random string> + # They are scoped to user localparts, so it is possible for + # the same ID to exist for multiple users. + + return DelayID(f"syd_{stringutils.random_string(20)}") diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 042d595ea0..d47833655d 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py
@@ -200,9 +200,9 @@ class DeviceInboxWorkerStore(SQLBaseStore): to_stream_id=to_stream_id, ) - assert ( - last_processed_stream_id == to_stream_id - ), "Expected _get_device_messages to process all to-device messages up to `to_stream_id`" + assert last_processed_stream_id == to_stream_id, ( + "Expected _get_device_messages to process all to-device messages up to `to_stream_id`" + ) return user_id_device_id_to_messages @@ -1116,7 +1116,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): txn.execute(sql, (start, stop)) - destinations = {d for d, in txn} + destinations = {d for (d,) in txn} to_remove = set() for d in destinations: try: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 53024bddc3..6191f22cd6 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py
@@ -27,6 +27,7 @@ from typing import ( Dict, Iterable, List, + Literal, Mapping, Optional, Set, @@ -35,7 +36,6 @@ from typing import ( ) from canonicaljson import encode_canonical_json -from typing_extensions import Literal from synapse.api.constants import EduTypes from synapse.api.errors import Codes, StoreError @@ -282,9 +282,10 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): "count_devices_by_users", count_devices_by_users_txn, user_ids ) + @cached() async def get_device( self, user_id: str, device_id: str - ) -> Optional[Dict[str, Any]]: + ) -> Optional[Mapping[str, Any]]: """Retrieve a device. Only returns devices that are not marked as hidden. @@ -670,9 +671,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): result["keys"] = keys device_display_name = None - if ( - self.hs.config.federation.allow_device_name_lookup_over_federation - ): + if self.hs.config.federation.allow_device_name_lookup_over_federation: device_display_name = device.display_name if device_display_name: result["device_display_name"] = device_display_name @@ -917,7 +916,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): from_key, to_key, ) - return {u for u, in rows} + return {u for (u,) in rows} @cancellable async def get_users_whose_devices_changed( @@ -968,7 +967,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): txn.database_engine, "user_id", chunk ) txn.execute(sql % (clause,), [from_key, to_key] + args) - changes.update(user_id for user_id, in txn) + changes.update(user_id for (user_id,) in txn) return changes @@ -1093,7 +1092,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ), ) - results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids} + results: Dict[str, Optional[str]] = dict.fromkeys(user_ids) results.update(rows) return results @@ -1424,7 +1423,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): DELETE 
FROM device_lists_outbound_last_success WHERE destination = ? AND user_id = ? """ - txn.execute_batch(sql, ((row[0], row[1]) for row in rows)) + txn.execute_batch(sql, [(row[0], row[1]) for row in rows]) logger.info("Pruned %d device list outbound pokes", count) @@ -1520,7 +1519,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): args: List[Any], ) -> Set[str]: txn.execute(sql.format(clause=clause), args) - return {user_id for user_id, in txn} + return {user_id for (user_id,) in txn} changes = set() for chunk in batch_iter(changed_room_ids, 1000): @@ -1560,7 +1559,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): txn: LoggingTransaction, ) -> Set[str]: txn.execute(sql, (from_id, to_id)) - return {room_id for room_id, in txn} + return {room_id for (room_id,) in txn} return await self.db_pool.runInteraction( "get_all_device_list_changes", @@ -1819,6 +1818,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): }, desc="store_device", ) + await self.invalidate_cache_and_stream("get_device", (user_id, device_id)) + if not inserted: # if the device already exists, check if it's a real device, or # if the device ID is reserved by something else @@ -1884,6 +1885,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): values=device_ids, keyvalues={"user_id": user_id}, ) + self._invalidate_cache_and_stream_bulk( + txn, self.get_device, [(user_id, device_id) for device_id in device_ids] + ) for batch in batch_iter(device_ids, 100): await self.db_pool.runInteraction( @@ -1917,6 +1921,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): updatevalues=updates, desc="update_device", ) + await self.invalidate_cache_and_stream("get_device", (user_id, device_id)) async def update_remote_device_list_cache_entry( self, user_id: str, device_id: str, content: JsonDict, stream_id: str diff --git a/synapse/storage/databases/main/e2e_room_keys.py 
b/synapse/storage/databases/main/e2e_room_keys.py
index 4d6a921ab2..904ae5cb58 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -19,9 +19,18 @@ # # -from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple, cast - -from typing_extensions import Literal, TypedDict +from typing import ( + TYPE_CHECKING, + Dict, + Iterable, + List, + Literal, + Mapping, + Optional, + Tuple, + TypedDict, + cast, +) from synapse.api.errors import StoreError from synapse.logging.opentracing import log_kv, trace @@ -387,9 +396,7 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore): is_verified, session_data FROM e2e_room_keys WHERE user_id = ? AND version = ? AND (%s) - """ % ( - " OR ".join(where_clauses) - ) + """ % (" OR ".join(where_clauses)) txn.execute(sql, params) @@ -512,19 +519,16 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore): # it isn't there. raise StoreError(404, "No backup with that version exists") - row = cast( - Tuple[int, str, str, Optional[int]], - self.db_pool.simple_select_one_txn( - txn, - table="e2e_room_keys_versions", - keyvalues={ - "user_id": user_id, - "version": this_version, - "deleted": 0, - }, - retcols=("version", "algorithm", "auth_data", "etag"), - allow_none=False, - ), + row = self.db_pool.simple_select_one_txn( + txn, + table="e2e_room_keys_versions", + keyvalues={ + "user_id": user_id, + "version": this_version, + "deleted": 0, + }, + retcols=("version", "algorithm", "auth_data", "etag"), + allow_none=False, ) return { "auth_data": db_to_json(row[2]), diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 9e6c9561ae..341e7014d6 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -27,6 +27,7 @@ from typing import ( Dict, Iterable, List, + Literal, Mapping, Optional, Sequence, @@ -39,7 +40,6 @@ from typing import ( import attr from canonicaljson import encode_canonical_json -from typing_extensions import Literal from synapse.api.constants import DeviceKeyAlgorithms from synapse.appservice import ( @@ -99,6 +99,13 @@ class EndToEndKeyBackgroundStore(SQLBaseStore): unique=True, ) + self.db_pool.updates.register_background_index_update( + update_name="add_otk_ts_added_index", + index_name="e2e_one_time_keys_json_user_id_device_id_algorithm_ts_added_idx", + table="e2e_one_time_keys_json", + columns=("user_id", "device_id", "algorithm", "ts_added_ms"), + ) + class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore): def __init__( @@ -472,9 +479,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker signature_sql = """ SELECT user_id, key_id, target_device_id, signature FROM e2e_cross_signing_signatures WHERE %s - """ % ( - " OR ".join("(" + q + ")" for q in signature_query_clauses) - ) + """ % (" OR ".join("(" + q + ")" for q in signature_query_clauses)) txn.execute(signature_sql, signature_query_params) return cast( @@ -917,9 +922,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker FROM e2e_cross_signing_keys WHERE %(clause)s ORDER BY user_id, keytype, stream_id DESC - """ % { - "clause": clause - } + """ % {"clause": clause} else: # SQLite has special handling for bare columns when using # MIN/MAX with a `GROUP BY` clause where it picks the value from @@ -929,9 +932,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker FROM e2e_cross_signing_keys WHERE %(clause)s GROUP BY user_id, keytype - """ % { - "clause": clause - } + """ % {"clause": clause} txn.execute(sql, params) @@ -1128,7 +1129,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker """Take a list of one time keys out of the 
database. Args: - query_list: An iterable of tuples of (user ID, device ID, algorithm). + query_list: An iterable of tuples of (user ID, device ID, algorithm, number of keys). Returns: A tuple (results, missing) of: @@ -1316,9 +1317,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker OTK was found. """ + # Return the oldest keys from this device (based on `ts_added_ms`). + # Doing so means that keys are issued in the same order they were uploaded, + # which reduces the chances of a client expiring its copy of a (private) + # key while the public key is still on the server, waiting to be issued. sql = """ SELECT key_id, key_json FROM e2e_one_time_keys_json WHERE user_id = ? AND device_id = ? AND algorithm = ? + ORDER BY ts_added_ms LIMIT ? """ @@ -1360,13 +1366,22 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker A list of tuples (user_id, device_id, algorithm, key_id, key_json) for each OTK claimed. """ + # Find, delete, and return the oldest keys from each device (based on + # `ts_added_ms`). + # + # Doing so means that keys are issued in the same order they were uploaded, + # which reduces the chances of a client expiring its copy of a (private) + # key while the public key is still on the server, waiting to be issued. sql = """ WITH claims(user_id, device_id, algorithm, claim_count) AS ( VALUES ? 
), ranked_keys AS ( SELECT user_id, device_id, algorithm, key_id, claim_count, - ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r + ROW_NUMBER() OVER ( + PARTITION BY (user_id, device_id, algorithm) + ORDER BY ts_added_ms + ) AS r FROM e2e_one_time_keys_json JOIN claims USING (user_id, device_id, algorithm) ) @@ -1438,6 +1453,93 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker impl, ) + async def delete_old_otks_for_next_user_batch( + self, after_user_id: str, number_of_users: int + ) -> Tuple[List[str], int]: + """Deletes old OTKs belonging to the next batch of users + + Returns: + `(users, rows)`, where: + * `users` is the user IDs of the updated users. An empty list if we are done. + * `rows` is the number of deleted rows + """ + + def impl(txn: LoggingTransaction) -> Tuple[List[str], int]: + # Find a batch of users + txn.execute( + """ + SELECT DISTINCT(user_id) FROM e2e_one_time_keys_json + WHERE user_id > ? + ORDER BY user_id + LIMIT ? + """, + (after_user_id, number_of_users), + ) + users = [row[0] for row in txn.fetchall()] + if len(users) == 0: + return users, 0 + + # Delete any old OTKs belonging to those users. + # + # We only actually consider OTKs whose key ID is 6 characters long. These + # keys were likely made by libolm rather than Vodozemac; libolm only kept + # 100 private OTKs, so was far more vulnerable than Vodozemac to throwing + # away keys prematurely. + clause, args = make_in_list_sql_clause( + txn.database_engine, "user_id", users + ) + sql = f""" + DELETE FROM e2e_one_time_keys_json + WHERE {clause} AND ts_added_ms < ? 
AND length(key_id) = 6 + """ + args.append(self._clock.time_msec() - (7 * 24 * 3600 * 1000)) + txn.execute(sql, args) + + return users, txn.rowcount + + return await self.db_pool.runInteraction( + "delete_old_otks_for_next_user_batch", impl + ) + + async def allow_master_cross_signing_key_replacement_without_uia( + self, user_id: str, duration_ms: int + ) -> Optional[int]: + """Mark this user's latest master key as being replaceable without UIA. + + Said replacement will only be permitted for a short time after calling this + function. That time period is controlled by the duration argument. + + Returns: + None, if there is no such key. + Otherwise, the timestamp before which replacement is allowed without UIA. + """ + timestamp = self._clock.time_msec() + duration_ms + + def impl(txn: LoggingTransaction) -> Optional[int]: + txn.execute( + """ + UPDATE e2e_cross_signing_keys + SET updatable_without_uia_before_ms = ? + WHERE stream_id = ( + SELECT stream_id + FROM e2e_cross_signing_keys + WHERE user_id = ? AND keytype = 'master' + ORDER BY stream_id DESC + LIMIT 1 + ) + """, + (timestamp, user_id), + ) + if txn.rowcount == 0: + return None + + return timestamp + + return await self.db_pool.runInteraction( + "allow_master_cross_signing_key_replacement_without_uia", + impl, + ) + class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): def __init__( @@ -1692,42 +1794,3 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): ], desc="add_e2e_signing_key", ) - - async def allow_master_cross_signing_key_replacement_without_uia( - self, user_id: str, duration_ms: int - ) -> Optional[int]: - """Mark this user's latest master key as being replaceable without UIA. - - Said replacement will only be permitted for a short time after calling this - function. That time period is controlled by the duration argument. - - Returns: - None, if there is no such key. - Otherwise, the timestamp before which replacement is allowed without UIA. 
- """ - timestamp = self._clock.time_msec() + duration_ms - - def impl(txn: LoggingTransaction) -> Optional[int]: - txn.execute( - """ - UPDATE e2e_cross_signing_keys - SET updatable_without_uia_before_ms = ? - WHERE stream_id = ( - SELECT stream_id - FROM e2e_cross_signing_keys - WHERE user_id = ? AND keytype = 'master' - ORDER BY stream_id DESC - LIMIT 1 - ) - """, - (timestamp, user_id), - ) - if txn.rowcount == 0: - return None - - return timestamp - - return await self.db_pool.runInteraction( - "allow_master_cross_signing_key_replacement_without_uia", - impl, - ) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 715846865b..46aa5902d8 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py
@@ -326,7 +326,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas """ rows = txn.execute_values(sql, chains.items()) - results.update(r for r, in rows) + results.update(r for (r,) in rows) else: # For SQLite we just fall back to doing a noddy for loop. sql = """ @@ -335,7 +335,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas """ for chain_id, max_no in chains.items(): txn.execute(sql, (chain_id, max_no)) - results.update(r for r, in txn) + results.update(r for (r,) in txn) return results @@ -645,7 +645,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ] rows = txn.execute_values(sql, args) - result.update(r for r, in rows) + result.update(r for (r,) in rows) else: # For SQLite we just fall back to doing a noddy for loop. sql = """ @@ -654,7 +654,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas """ for chain_id, (min_no, max_no) in chain_to_gap.items(): txn.execute(sql, (chain_id, min_no, max_no)) - result.update(r for r, in txn) + result.update(r for (r,) in txn) return result @@ -1220,13 +1220,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas HAVING count(*) > ? ORDER BY count(*) DESC LIMIT ? 
- """ % ( - where_clause, - ) + """ % (where_clause,) query_args = list(itertools.chain(room_id_filter, [min_count, limit])) txn.execute(sql, query_args) - return [room_id for room_id, in txn] + return [room_id for (room_id,) in txn] return await self.db_pool.runInteraction( "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn @@ -1358,7 +1356,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas def get_forward_extremeties_for_room_txn(txn: LoggingTransaction) -> List[str]: txn.execute(sql, (stream_ordering, room_id)) - return [event_id for event_id, in txn] + return [event_id for (event_id,) in txn] event_ids = await self.db_pool.runInteraction( "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 0ebf5b53d5..6fb4a6df8c 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py
@@ -1034,97 +1034,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # one of the subqueries may have hit the limit. return notifs[:limit] - async def get_unread_push_actions_for_user_in_range_for_email( - self, - user_id: str, - min_stream_ordering: int, - max_stream_ordering: int, - limit: int = 20, - ) -> List[EmailPushAction]: - """Get a list of the most recent unread push actions for a given user, - within the given stream ordering range. Called by the emailpusher - - Args: - user_id: The user to fetch push actions for. - min_stream_ordering: The exclusive lower bound on the - stream ordering of event push actions to fetch. - max_stream_ordering: The inclusive upper bound on the - stream ordering of event push actions to fetch. - limit: The maximum number of rows to return. - Returns: - A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions", "received_ts". - The list will be ordered by descending received_ts. - The list will have between 0~limit entries. - """ - - def get_push_actions_txn( - txn: LoggingTransaction, - ) -> List[Tuple[str, str, str, int, str, bool, int]]: - sql = """ - SELECT ep.event_id, ep.room_id, ep.thread_id, ep.stream_ordering, - ep.actions, ep.highlight, e.received_ts - FROM event_push_actions AS ep - INNER JOIN events AS e USING (room_id, event_id) - WHERE - ep.user_id = ? - AND ep.stream_ordering > ? - AND ep.stream_ordering <= ? - AND ep.notif = 1 - ORDER BY ep.stream_ordering DESC LIMIT ? 
- """ - txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit)) - return cast(List[Tuple[str, str, str, int, str, bool, int]], txn.fetchall()) - - push_actions = await self.db_pool.runInteraction( - "get_unread_push_actions_for_user_in_range_email", get_push_actions_txn - ) - - room_ids = set() - thread_ids = [] - for ( - _, - room_id, - thread_id, - _, - _, - _, - _, - ) in push_actions: - room_ids.add(room_id) - thread_ids.append(thread_id) - - receipts_by_room = await self.db_pool.runInteraction( - "get_unread_push_actions_for_user_in_range_email_receipts", - self._get_receipts_for_room_and_threads_txn, - user_id=user_id, - room_ids=room_ids, - thread_ids=thread_ids, - ) - - # Make a list of dicts from the two sets of results. - notifs = [ - EmailPushAction( - event_id=event_id, - room_id=room_id, - stream_ordering=stream_ordering, - actions=_deserialize_action(actions, highlight), - received_ts=received_ts, - ) - for event_id, room_id, thread_id, stream_ordering, actions, highlight, received_ts in push_actions - if receipts_by_room.get(room_id, MISSING_ROOM_RECEIPT).is_unread( - thread_id, stream_ordering - ) - ] - - # Now sort it so it's ordered correctly, since currently it will - # contain results from the first query, correctly ordered, followed - # by results from the second query, but we want them all ordered - # by received_ts (most recent first) - notifs.sort(key=lambda r: -(r.received_ts or 0)) - - # Now return the first `limit` - return notifs[:limit] - async def get_if_maybe_push_in_range_for_user( self, user_id: str, min_stream_ordering: int ) -> bool: @@ -1860,9 +1769,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas AND epa.notif = 1 ORDER BY epa.stream_ordering DESC LIMIT ? 
- """ % ( - before_clause, - ) + """ % (before_clause,) txn.execute(sql, args) return cast( List[Tuple[str, str, int, int, str, bool, str, int]], txn.fetchall() diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1f7acdb859..b7cc0433e7 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py
@@ -32,8 +32,10 @@ from typing import ( Iterable, List, Optional, + Sequence, Set, Tuple, + TypedDict, cast, ) @@ -41,17 +43,24 @@ import attr from prometheus_client import Counter import synapse.metrics -from synapse.api.constants import EventContentFields, EventTypes, RelationTypes +from synapse.api.constants import ( + EventContentFields, + EventTypes, + Membership, + RelationTypes, +) from synapse.api.errors import PartialStateConflictError from synapse.api.room_versions import RoomVersions -from synapse.events import EventBase, relation_from_event +from synapse.events import EventBase, StrippedStateEvent, relation_from_event from synapse.events.snapshot import EventContext +from synapse.events.utils import parse_stripped_state_event from synapse.logging.opentracing import trace from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, + make_tuple_in_list_sql_clause, ) from synapse.storage.databases.main.event_federation import EventFederationStore from synapse.storage.databases.main.events_worker import EventCacheEntry @@ -59,7 +68,15 @@ from synapse.storage.databases.main.search import SearchEntry from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import AbstractStreamIdGenerator from synapse.storage.util.sequence import SequenceGenerator -from synapse.types import JsonDict, StateMap, StrCollection, get_domain_from_id +from synapse.types import ( + JsonDict, + MutableStateMap, + StateMap, + StrCollection, + get_domain_from_id, +) +from synapse.types.handlers import SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES +from synapse.types.state import StateFilter from synapse.util import json_encoder from synapse.util.iterutils import batch_iter, sorted_topologically from synapse.util.stringutils import non_null_str_or_none @@ -78,6 +95,19 @@ event_counter = Counter( ["type", "origin_type", "origin_entity"], ) +# State event 
type/key pairs that we need to gather to fill in the +# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` tables. +SLIDING_SYNC_RELEVANT_STATE_SET = ( + # So we can fill in the `room_type` column + (EventTypes.Create, ""), + # So we can fill in the `is_encrypted` column + (EventTypes.RoomEncryption, ""), + # So we can fill in the `room_name` column + (EventTypes.Name, ""), + # So we can fill in the `tombstone_successor_room_id` column + (EventTypes.Tombstone, ""), +) + @attr.s(slots=True, auto_attribs=True) class DeltaState: @@ -99,6 +129,80 @@ class DeltaState: return not self.to_delete and not self.to_insert and not self.no_longer_in_room +# We want `total=False` because we want to allow values to be unset. +class SlidingSyncStateInsertValues(TypedDict, total=False): + """ + Insert values relevant for the `sliding_sync_joined_rooms` and + `sliding_sync_membership_snapshots` database tables. + """ + + room_type: Optional[str] + is_encrypted: Optional[bool] + room_name: Optional[str] + tombstone_successor_room_id: Optional[str] + + +class SlidingSyncMembershipSnapshotSharedInsertValues( + SlidingSyncStateInsertValues, total=False +): + """ + Insert values for `sliding_sync_membership_snapshots` that we can share across + multiple memberships + """ + + has_known_state: Optional[bool] + + +@attr.s(slots=True, auto_attribs=True) +class SlidingSyncMembershipInfo: + """ + Values unique to each membership + """ + + user_id: str + sender: str + membership_event_id: str + membership: str + + +@attr.s(slots=True, auto_attribs=True) +class SlidingSyncMembershipInfoWithEventPos(SlidingSyncMembershipInfo): + """ + SlidingSyncMembershipInfo + `stream_ordering`/`instance_name` of the membership + event + """ + + membership_event_stream_ordering: int + membership_event_instance_name: str + + +@attr.s(slots=True, auto_attribs=True) +class SlidingSyncTableChanges: + room_id: str + # If the row doesn't exist in the `sliding_sync_joined_rooms` table, we need to + # 
fully-insert it which means we also need to include a `bump_stamp` value to use + # for the row. This should only be populated when we're trying to fully-insert a + # row. + # + # FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the + # foreground update for + # `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by + # https://github.com/element-hq/synapse/issues/17623) + joined_room_bump_stamp_to_fully_insert: Optional[int] + # Values to upsert into `sliding_sync_joined_rooms` + joined_room_updates: SlidingSyncStateInsertValues + + # Shared values to upsert into `sliding_sync_membership_snapshots` for each + # `to_insert_membership_snapshots` + membership_snapshot_shared_insert_values: ( + SlidingSyncMembershipSnapshotSharedInsertValues + ) + # List of membership to insert into `sliding_sync_membership_snapshots` + to_insert_membership_snapshots: List[SlidingSyncMembershipInfo] + # List of user_id to delete from `sliding_sync_membership_snapshots` + to_delete_membership_snapshots: List[str] + + @attr.s(slots=True, auto_attribs=True) class NewEventChainLinks: """Information about new auth chain links that need to be added to the DB. @@ -142,9 +246,9 @@ class PersistEventsStore: self.is_mine_id = hs.is_mine_id # This should only exist on instances that are configured to write - assert ( - hs.get_instance_name() in hs.config.worker.writers.events - ), "Can only instantiate EventsStore on master" + assert hs.get_instance_name() in hs.config.worker.writers.events, ( + "Can only instantiate EventsStore on master" + ) # Since we have been configured to write, we ought to have id generators, # rather than id trackers. @@ -223,9 +327,24 @@ class PersistEventsStore: async with stream_ordering_manager as stream_orderings: for (event, _), stream in zip(events_and_contexts, stream_orderings): + # XXX: We can't rely on `stream_ordering`/`instance_name` being correct + # at this point. 
We could be working with events that were previously + # persisted as an `outlier` with one `stream_ordering` but are now being + # persisted again and de-outliered and are being assigned a different + # `stream_ordering` here that won't end up being used. + # `_update_outliers_txn()` will fix this discrepancy (always use the + # `stream_ordering` from the first time it was persisted). event.internal_metadata.stream_ordering = stream event.internal_metadata.instance_name = self._instance_name + sliding_sync_table_changes = None + if state_delta_for_room is not None: + sliding_sync_table_changes = ( + await self._calculate_sliding_sync_table_changes( + room_id, events_and_contexts, state_delta_for_room + ) + ) + await self.db_pool.runInteraction( "persist_events", self._persist_events_txn, @@ -235,6 +354,7 @@ class PersistEventsStore: state_delta_for_room=state_delta_for_room, new_forward_extremities=new_forward_extremities, new_event_links=new_event_links, + sliding_sync_table_changes=sliding_sync_table_changes, ) persist_event_counter.inc(len(events_and_contexts)) @@ -261,6 +381,301 @@ class PersistEventsStore: (room_id,), frozenset(new_forward_extremities) ) + async def _calculate_sliding_sync_table_changes( + self, + room_id: str, + events_and_contexts: Sequence[Tuple[EventBase, EventContext]], + delta_state: DeltaState, + ) -> SlidingSyncTableChanges: + """ + Calculate the changes to the `sliding_sync_membership_snapshots` and + `sliding_sync_joined_rooms` tables given the deltas that are going to be used to + update the `current_state_events` table. + + Just a bunch of pre-processing so we so we don't need to spend time in the + transaction itself gathering all of this info. It's also easier to deal with + redactions outside of a transaction. + + Args: + room_id: The room ID currently being processed. + events_and_contexts: List of tuples of (event, context) being persisted. 
+ This is completely optional (you can pass an empty list) and will just + save us from fetching the events from the database if we already have + them. We assume the list is sorted ascending by `stream_ordering`. We + don't care about the sort when the events are backfilled (with negative + `stream_ordering`). + delta_state: Deltas that are going to be used to update the + `current_state_events` table. Changes to the current state of the room. + + Returns: + SlidingSyncTableChanges + """ + to_insert = delta_state.to_insert + to_delete = delta_state.to_delete + + # If no state is changing, we don't need to do anything. This can happen when a + # partial-stated room is re-syncing the current state. + if not to_insert and not to_delete: + return SlidingSyncTableChanges( + room_id=room_id, + joined_room_bump_stamp_to_fully_insert=None, + joined_room_updates={}, + membership_snapshot_shared_insert_values={}, + to_insert_membership_snapshots=[], + to_delete_membership_snapshots=[], + ) + + event_map = {event.event_id: event for event, _ in events_and_contexts} + + # Handle gathering info for the `sliding_sync_membership_snapshots` table + # + # This would only happen if someone was state reset out of the room + user_ids_to_delete_membership_snapshots = [ + state_key + for event_type, state_key in to_delete + if event_type == EventTypes.Member and self.is_mine_id(state_key) + ] + + membership_snapshot_shared_insert_values: SlidingSyncMembershipSnapshotSharedInsertValues = {} + membership_infos_to_insert_membership_snapshots: List[ + SlidingSyncMembershipInfo + ] = [] + if to_insert: + membership_event_id_to_user_id_map: Dict[str, str] = {} + for state_key, event_id in to_insert.items(): + if state_key[0] == EventTypes.Member and self.is_mine_id(state_key[1]): + membership_event_id_to_user_id_map[event_id] = state_key[1] + + membership_event_map: Dict[str, EventBase] = {} + # In normal event persist scenarios, we should be able to find the + # membership events in the 
`events_and_contexts` given to us but it's + # possible a state reset happened which added us to the room without a + # corresponding new membership event (reset back to a previous membership). + missing_membership_event_ids: Set[str] = set() + for membership_event_id in membership_event_id_to_user_id_map.keys(): + membership_event = event_map.get(membership_event_id) + if membership_event: + membership_event_map[membership_event_id] = membership_event + else: + missing_membership_event_ids.add(membership_event_id) + + # Otherwise, we need to find a couple events that we were reset to. + if missing_membership_event_ids: + remaining_events = await self.store.get_events( + missing_membership_event_ids + ) + # There shouldn't be any missing events + assert remaining_events.keys() == missing_membership_event_ids, ( + missing_membership_event_ids.difference(remaining_events.keys()) + ) + membership_event_map.update(remaining_events) + + for ( + membership_event_id, + user_id, + ) in membership_event_id_to_user_id_map.items(): + membership_infos_to_insert_membership_snapshots.append( + # XXX: We don't use `SlidingSyncMembershipInfoWithEventPos` here + # because we're sourcing the event from `events_and_contexts`, we + # can't rely on `stream_ordering`/`instance_name` being correct at + # this point. We could be working with events that were previously + # persisted as an `outlier` with one `stream_ordering` but are now + # being persisted again and de-outliered and assigned a different + # `stream_ordering` that won't end up being used. Since we call + # `_calculate_sliding_sync_table_changes()` before + # `_update_outliers_txn()` which fixes this discrepancy (always use + # the `stream_ordering` from the first time it was persisted), we're + # working with an unreliable `stream_ordering` value that will + # possibly be unused and not make it into the `events` table. 
+ SlidingSyncMembershipInfo( + user_id=user_id, + sender=membership_event_map[membership_event_id].sender, + membership_event_id=membership_event_id, + membership=membership_event_map[membership_event_id].membership, + ) + ) + + if membership_infos_to_insert_membership_snapshots: + current_state_ids_map: MutableStateMap[str] = dict( + await self.store.get_partial_filtered_current_state_ids( + room_id, + state_filter=StateFilter.from_types( + SLIDING_SYNC_RELEVANT_STATE_SET + ), + ) + ) + # Since we fetched the current state before we took `to_insert`/`to_delete` + # into account, we need to do a couple fixups. + # + # Update the current_state_map with what we have `to_delete` + for state_key in to_delete: + current_state_ids_map.pop(state_key, None) + # Update the current_state_map with what we have `to_insert` + for state_key, event_id in to_insert.items(): + if state_key in SLIDING_SYNC_RELEVANT_STATE_SET: + current_state_ids_map[state_key] = event_id + + current_state_map: MutableStateMap[EventBase] = {} + # In normal event persist scenarios, we probably won't be able to find + # these state events in `events_and_contexts` since we don't generally + # batch up local membership changes with other events, but it can + # happen. 
+ missing_state_event_ids: Set[str] = set() + for state_key, event_id in current_state_ids_map.items(): + event = event_map.get(event_id) + if event: + current_state_map[state_key] = event + else: + missing_state_event_ids.add(event_id) + + # Otherwise, we need to find a couple events + if missing_state_event_ids: + remaining_events = await self.store.get_events( + missing_state_event_ids + ) + # There shouldn't be any missing events + assert remaining_events.keys() == missing_state_event_ids, ( + missing_state_event_ids.difference(remaining_events.keys()) + ) + for event in remaining_events.values(): + current_state_map[(event.type, event.state_key)] = event + + if current_state_map: + state_insert_values = PersistEventsStore._get_sliding_sync_insert_values_from_state_map( + current_state_map + ) + membership_snapshot_shared_insert_values.update(state_insert_values) + # We have current state to work from + membership_snapshot_shared_insert_values["has_known_state"] = True + else: + # We don't have any `current_state_events` anymore (previously + # cleared out because of `no_longer_in_room`). This can happen if + # one user is joined and another is invited (some non-join + # membership). If the joined user leaves, we are `no_longer_in_room` + # and `current_state_events` is cleared out. When the invited user + # rejects the invite (leaves the room), we will end up here. + # + # In these cases, we should inherit the meta data from the previous + # snapshot so we shouldn't update any of the state values. When + # using sliding sync filters, this will prevent the room from + # disappearing/appearing just because you left the room. + # + # Ideally, we could additionally assert that we're only here for + # valid non-join membership transitions. + assert delta_state.no_longer_in_room + + # Handle gathering info for the `sliding_sync_joined_rooms` table + # + # We only deal with + # updating the state related columns. 
The + # `event_stream_ordering`/`bump_stamp` are updated elsewhere in the event + # persisting stack (see + # `_update_sliding_sync_tables_with_new_persisted_events_txn()`) + # + joined_room_updates: SlidingSyncStateInsertValues = {} + bump_stamp_to_fully_insert: Optional[int] = None + if not delta_state.no_longer_in_room: + current_state_ids_map = {} + + # Always fully-insert rows if they don't already exist in the + # `sliding_sync_joined_rooms` table. This way we can rely on a row if it + # exists in the table. + # + # FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the + # foreground update for + # `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by + # https://github.com/element-hq/synapse/issues/17623) + existing_row_in_table = await self.store.db_pool.simple_select_one_onecol( + table="sliding_sync_joined_rooms", + keyvalues={"room_id": room_id}, + retcol="room_id", + allow_none=True, + ) + if not existing_row_in_table: + most_recent_bump_event_pos_results = ( + await self.store.get_last_event_pos_in_room( + room_id, + event_types=SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES, + ) + ) + if most_recent_bump_event_pos_results is not None: + _, new_bump_event_pos = most_recent_bump_event_pos_results + + # If we've just joined a remote room, then the last bump event may + # have been backfilled (and so have a negative stream ordering). + # These negative stream orderings can't sensibly be compared, so + # instead just leave it as `None` in the table and we will use their + # membership event position as the bump event position in the + # Sliding Sync API. 
+ if new_bump_event_pos.stream > 0: + bump_stamp_to_fully_insert = new_bump_event_pos.stream + + current_state_ids_map = dict( + await self.store.get_partial_filtered_current_state_ids( + room_id, + state_filter=StateFilter.from_types( + SLIDING_SYNC_RELEVANT_STATE_SET + ), + ) + ) + + # Look through the items we're going to insert into the current state to see + # if there is anything that we care about and should also update in the + # `sliding_sync_joined_rooms` table. + for state_key, event_id in to_insert.items(): + if state_key in SLIDING_SYNC_RELEVANT_STATE_SET: + current_state_ids_map[state_key] = event_id + + # Get the full event objects for the current state events + # + # In normal event persist scenarios, we should be able to find the state + # events in the `events_and_contexts` given to us but it's possible a state + # reset happened which that reset back to a previous state. + current_state_map = {} + missing_event_ids: Set[str] = set() + for state_key, event_id in current_state_ids_map.items(): + event = event_map.get(event_id) + if event: + current_state_map[state_key] = event + else: + missing_event_ids.add(event_id) + + # Otherwise, we need to find a couple events that we were reset to. 
+ if missing_event_ids: + remaining_events = await self.store.get_events(missing_event_ids) + # There shouldn't be any missing events + assert remaining_events.keys() == missing_event_ids, ( + missing_event_ids.difference(remaining_events.keys()) + ) + for event in remaining_events.values(): + current_state_map[(event.type, event.state_key)] = event + + joined_room_updates = ( + PersistEventsStore._get_sliding_sync_insert_values_from_state_map( + current_state_map + ) + ) + + # If something is being deleted from the state, we need to clear it out + for state_key in to_delete: + if state_key == (EventTypes.Create, ""): + joined_room_updates["room_type"] = None + elif state_key == (EventTypes.RoomEncryption, ""): + joined_room_updates["is_encrypted"] = False + elif state_key == (EventTypes.Name, ""): + joined_room_updates["room_name"] = None + + return SlidingSyncTableChanges( + room_id=room_id, + # For `sliding_sync_joined_rooms` + joined_room_bump_stamp_to_fully_insert=bump_stamp_to_fully_insert, + joined_room_updates=joined_room_updates, + # For `sliding_sync_membership_snapshots` + membership_snapshot_shared_insert_values=membership_snapshot_shared_insert_values, + to_insert_membership_snapshots=membership_infos_to_insert_membership_snapshots, + to_delete_membership_snapshots=user_ids_to_delete_membership_snapshots, + ) + async def calculate_chain_cover_index_for_events( self, room_id: str, events: Collection[EventBase] ) -> Dict[str, NewEventChainLinks]: @@ -315,7 +730,7 @@ class PersistEventsStore: keyvalues={}, retcols=("event_id",), ) - already_persisted_events = {event_id for event_id, in rows} + already_persisted_events = {event_id for (event_id,) in rows} state_events = [ event for event in state_events @@ -458,6 +873,7 @@ class PersistEventsStore: state_delta_for_room: Optional[DeltaState], new_forward_extremities: Optional[Set[str]], new_event_links: Dict[str, NewEventChainLinks], + sliding_sync_table_changes: Optional[SlidingSyncTableChanges], ) -> 
None: """Insert some number of room events into the necessary database tables. @@ -478,9 +894,14 @@ class PersistEventsStore: delete_existing True to purge existing table rows for the events from the database. This is useful when retrying due to IntegrityError. - state_delta_for_room: The current-state delta for the room. + state_delta_for_room: Deltas that are going to be used to update the + `current_state_events` table. Changes to the current state of the room. new_forward_extremities: The new forward extremities for the room: a set of the event ids which are the forward extremities. + sliding_sync_table_changes: Changes to the + `sliding_sync_membership_snapshots` and `sliding_sync_joined_rooms` tables + derived from the given `delta_state` (see + `_calculate_sliding_sync_table_changes(...)`) Raises: PartialStateConflictError: if attempting to persist a partial state event in @@ -590,10 +1011,22 @@ class PersistEventsStore: # room_memberships, where applicable. # NB: This function invalidates all state related caches if state_delta_for_room: + # If the state delta exists, the sliding sync table changes should also exist + assert sliding_sync_table_changes is not None + self._update_current_state_txn( - txn, room_id, state_delta_for_room, min_stream_order + txn, + room_id, + state_delta_for_room, + min_stream_order, + sliding_sync_table_changes, ) + # We only update the sliding sync tables for non-backfilled events. 
+ self._update_sliding_sync_tables_with_new_persisted_events_txn( + txn, room_id, events_and_contexts + ) + def _persist_event_auth_chain_txn( self, txn: LoggingTransaction, @@ -1128,8 +1561,20 @@ class PersistEventsStore: self, room_id: str, state_delta: DeltaState, + sliding_sync_table_changes: SlidingSyncTableChanges, ) -> None: - """Update the current state stored in the datatabase for the given room""" + """ + Update the current state stored in the datatabase for the given room + + Args: + room_id + state_delta: Deltas that are going to be used to update the + `current_state_events` table. Changes to the current state of the room. + sliding_sync_table_changes: Changes to the + `sliding_sync_membership_snapshots` and `sliding_sync_joined_rooms` tables + derived from the given `delta_state` (see + `_calculate_sliding_sync_table_changes(...)`) + """ if state_delta.is_noop(): return @@ -1141,6 +1586,7 @@ class PersistEventsStore: room_id, delta_state=state_delta, stream_id=stream_ordering, + sliding_sync_table_changes=sliding_sync_table_changes, ) def _update_current_state_txn( @@ -1149,16 +1595,40 @@ class PersistEventsStore: room_id: str, delta_state: DeltaState, stream_id: int, + sliding_sync_table_changes: SlidingSyncTableChanges, ) -> None: + """ + Handles updating tables that track the current state of a room. + + Args: + txn + room_id + delta_state: Deltas that are going to be used to update the + `current_state_events` table. Changes to the current state of the room. + stream_id: This is expected to be the minimum `stream_ordering` for the + batch of events that we are persisting; which means we do not end up in a + situation where workers see events before the `current_state_delta` updates. 
+ FIXME: However, this function also gets called with next upcoming + `stream_ordering` when we re-sync the state of a partial stated room (see + `update_current_state(...)`) which may be "correct" but it would be good to + nail down what exactly is the expected value here. + sliding_sync_table_changes: Changes to the + `sliding_sync_membership_snapshots` and `sliding_sync_joined_rooms` tables + derived from the given `delta_state` (see + `_calculate_sliding_sync_table_changes(...)`) + """ to_delete = delta_state.to_delete to_insert = delta_state.to_insert + # Sanity check we're processing the same thing + assert room_id == sliding_sync_table_changes.room_id + # Figure out the changes of membership to invalidate the # `get_rooms_for_user` cache. # We find out which membership events we may have deleted # and which we have added, then we invalidate the caches for all # those users. - members_changed = { + members_to_cache_bust = { state_key for ev_type, state_key in itertools.chain(to_delete, to_insert) if ev_type == EventTypes.Member @@ -1182,16 +1652,22 @@ class PersistEventsStore: """ txn.execute(sql, (stream_id, self._instance_name, room_id)) + # Grab the list of users before we clear out the current state + users_in_room = self.store.get_users_in_room_txn(txn, room_id) # We also want to invalidate the membership caches for users # that were in the room. - users_in_room = self.store.get_users_in_room_txn(txn, room_id) - members_changed.update(users_in_room) + members_to_cache_bust.update(users_in_room) self.db_pool.simple_delete_txn( txn, table="current_state_events", keyvalues={"room_id": room_id}, ) + self.db_pool.simple_delete_txn( + txn, + table="sliding_sync_joined_rooms", + keyvalues={"room_id": room_id}, + ) else: # We're still in the room, so we update the current state as normal. 
@@ -1216,7 +1692,7 @@ class PersistEventsStore: """ txn.execute_batch( sql, - ( + [ ( stream_id, self._instance_name, @@ -1229,17 +1705,17 @@ class PersistEventsStore: state_key, ) for etype, state_key in itertools.chain(to_delete, to_insert) - ), + ], ) # Now we actually update the current_state_events table txn.execute_batch( "DELETE FROM current_state_events" " WHERE room_id = ? AND type = ? AND state_key = ?", - ( + [ (room_id, etype, state_key) for etype, state_key in itertools.chain(to_delete, to_insert) - ), + ], ) # We include the membership in the current state table, hence we do @@ -1260,6 +1736,63 @@ class PersistEventsStore: ], ) + # Handle updating the `sliding_sync_joined_rooms` table. We only deal with + # updating the state related columns. The + # `event_stream_ordering`/`bump_stamp` are updated elsewhere in the event + # persisting stack (see + # `_update_sliding_sync_tables_with_new_persisted_events_txn()`) + # + # We only need to update when one of the relevant state values has changed + if sliding_sync_table_changes.joined_room_updates: + sliding_sync_updates_keys = ( + sliding_sync_table_changes.joined_room_updates.keys() + ) + sliding_sync_updates_values = ( + sliding_sync_table_changes.joined_room_updates.values() + ) + + args: List[Any] = [ + room_id, + room_id, + sliding_sync_table_changes.joined_room_bump_stamp_to_fully_insert, + ] + args.extend(iter(sliding_sync_updates_values)) + + # XXX: We use a sub-query for `stream_ordering` because it's unreliable to + # pre-calculate from `events_and_contexts` at the time when + # `_calculate_sliding_sync_table_changes()` is ran. We could be working + # with events that were previously persisted as an `outlier` with one + # `stream_ordering` but are now being persisted again and de-outliered + # and assigned a different `stream_ordering`. 
Since we call + # `_calculate_sliding_sync_table_changes()` before + # `_update_outliers_txn()` which fixes this discrepancy (always use the + # `stream_ordering` from the first time it was persisted), we're working + # with an unreliable `stream_ordering` value that will possibly be + # unused and not make it into the `events` table. + # + # We don't update `event_stream_ordering` `ON CONFLICT` because it's + # simpler and we can just rely on + # `_update_sliding_sync_tables_with_new_persisted_events_txn()` to do + # the right thing (same for `bump_stamp`). The only reason we're + # inserting `event_stream_ordering` here is because the column has a + # `NON NULL` constraint and we need some answer. + txn.execute( + f""" + INSERT INTO sliding_sync_joined_rooms + (room_id, event_stream_ordering, bump_stamp, {", ".join(sliding_sync_updates_keys)}) + VALUES ( + ?, + (SELECT stream_ordering FROM events WHERE room_id = ? ORDER BY stream_ordering DESC LIMIT 1), + ?, + {", ".join("?" for _ in sliding_sync_updates_values)} + ) + ON CONFLICT (room_id) + DO UPDATE SET + {", ".join(f"{key} = EXCLUDED.{key}" for key in sliding_sync_updates_keys)} + """, + args, + ) + # We now update `local_current_membership`. We do this regardless # of whether we're still in the room or not to handle the case where # e.g. we just got banned (where we need to record that fact here). @@ -1272,11 +1805,11 @@ class PersistEventsStore: txn.execute_batch( "DELETE FROM local_current_membership" " WHERE room_id = ? 
AND user_id = ?", - ( + [ (room_id, state_key) for etype, state_key in itertools.chain(to_delete, to_insert) if etype == EventTypes.Member and self.is_mine_id(state_key) - ), + ], ) if to_insert: @@ -1296,20 +1829,422 @@ class PersistEventsStore: ], ) + # Handle updating the `sliding_sync_membership_snapshots` table + # + # This would only happen if someone was state reset out of the room + if sliding_sync_table_changes.to_delete_membership_snapshots: + self.db_pool.simple_delete_many_txn( + txn, + table="sliding_sync_membership_snapshots", + column="user_id", + values=sliding_sync_table_changes.to_delete_membership_snapshots, + keyvalues={"room_id": room_id}, + ) + + # We do this regardless of whether the server is `no_longer_in_room` or not + # because we still want a row if a local user was just left/kicked or got banned + # from the room. + if sliding_sync_table_changes.to_insert_membership_snapshots: + # Update the `sliding_sync_membership_snapshots` table + # + sliding_sync_snapshot_keys = sliding_sync_table_changes.membership_snapshot_shared_insert_values.keys() + sliding_sync_snapshot_values = sliding_sync_table_changes.membership_snapshot_shared_insert_values.values() + # We need to insert/update regardless of whether we have + # `sliding_sync_snapshot_keys` because there are other fields in the `ON + # CONFLICT` upsert to run (see inherit case (explained in + # `_calculate_sliding_sync_table_changes()`) for more context when this + # happens). + # + # XXX: We use a sub-query for `stream_ordering` because it's unreliable to + # pre-calculate from `events_and_contexts` at the time when + # `_calculate_sliding_sync_table_changes()` is ran. We could be working with + # events that were previously persisted as an `outlier` with one + # `stream_ordering` but are now being persisted again and de-outliered and + # assigned a different `stream_ordering` that won't end up being used. 
Since + # we call `_calculate_sliding_sync_table_changes()` before + # `_update_outliers_txn()` which fixes this discrepancy (always use the + # `stream_ordering` from the first time it was persisted), we're working + # with an unreliable `stream_ordering` value that will possibly be unused + # and not make it into the `events` table. + txn.execute_batch( + f""" + INSERT INTO sliding_sync_membership_snapshots + (room_id, user_id, sender, membership_event_id, membership, forgotten, event_stream_ordering, event_instance_name + {("," + ", ".join(sliding_sync_snapshot_keys)) if sliding_sync_snapshot_keys else ""}) + VALUES ( + ?, ?, ?, ?, ?, ?, + (SELECT stream_ordering FROM events WHERE event_id = ?), + (SELECT COALESCE(instance_name, 'master') FROM events WHERE event_id = ?) + {("," + ", ".join("?" for _ in sliding_sync_snapshot_values)) if sliding_sync_snapshot_values else ""} + ) + ON CONFLICT (room_id, user_id) + DO UPDATE SET + sender = EXCLUDED.sender, + membership_event_id = EXCLUDED.membership_event_id, + membership = EXCLUDED.membership, + forgotten = EXCLUDED.forgotten, + event_stream_ordering = EXCLUDED.event_stream_ordering + {("," + ", ".join(f"{key} = EXCLUDED.{key}" for key in sliding_sync_snapshot_keys)) if sliding_sync_snapshot_keys else ""} + """, + [ + [ + room_id, + membership_info.user_id, + membership_info.sender, + membership_info.membership_event_id, + membership_info.membership, + # Since this is a new membership, it isn't forgotten anymore (which + # matches how Synapse currently thinks about the forgotten status) + 0, + # XXX: We do not use `membership_info.membership_event_stream_ordering` here + # because it is an unreliable value. See XXX note above. + membership_info.membership_event_id, + # XXX: We do not use `membership_info.membership_event_instance_name` here + # because it is an unreliable value. See XXX note above. 
+ membership_info.membership_event_id, + ] + + list(sliding_sync_snapshot_values) + for membership_info in sliding_sync_table_changes.to_insert_membership_snapshots + ], + ) + txn.call_after( self.store._curr_state_delta_stream_cache.entity_has_changed, room_id, stream_id, ) + for user_id in members_to_cache_bust: + txn.call_after( + self.store._membership_stream_cache.entity_has_changed, + user_id, + stream_id, + ) + # Invalidate the various caches - self.store._invalidate_state_caches_and_stream(txn, room_id, members_changed) + self.store._invalidate_state_caches_and_stream( + txn, room_id, members_to_cache_bust + ) # Check if any of the remote membership changes requires us to # unsubscribe from their device lists. self.store.handle_potentially_left_users_txn( - txn, {m for m in members_changed if not self.hs.is_mine_id(m)} + txn, {m for m in members_to_cache_bust if not self.hs.is_mine_id(m)} + ) + + @classmethod + def _get_relevant_sliding_sync_current_state_event_ids_txn( + cls, txn: LoggingTransaction, room_id: str + ) -> MutableStateMap[str]: + """ + Fetch the current state event IDs for the relevant (to the + `sliding_sync_joined_rooms` table) state types for the given room. + + Returns: + A tuple of: + 1. StateMap of event IDs necessary to to fetch the relevant state values + needed to insert into the + `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots`. + 2. The corresponding latest `stream_id` in the + `current_state_delta_stream` table. This is useful to compare against + the `current_state_delta_stream` table later so you can check whether + the current state has changed since you last fetched the current + state. 
+ """ + # Fetch the current state event IDs from the database + ( + event_type_and_state_key_in_list_clause, + event_type_and_state_key_args, + ) = make_tuple_in_list_sql_clause( + txn.database_engine, + ("type", "state_key"), + SLIDING_SYNC_RELEVANT_STATE_SET, + ) + txn.execute( + f""" + SELECT c.event_id, c.type, c.state_key + FROM current_state_events AS c + WHERE + c.room_id = ? + AND {event_type_and_state_key_in_list_clause} + """, + [room_id] + event_type_and_state_key_args, ) + current_state_map: MutableStateMap[str] = { + (event_type, state_key): event_id for event_id, event_type, state_key in txn + } + + return current_state_map + + @classmethod + def _get_sliding_sync_insert_values_from_state_map( + cls, state_map: StateMap[EventBase] + ) -> SlidingSyncStateInsertValues: + """ + Extract the relevant state values from the `state_map` needed to insert into the + `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` tables. + + Returns: + Map from column names (`room_type`, `is_encrypted`, `room_name`) to relevant + state values needed to insert into + the `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` tables. + """ + # Map of values to insert/update in the `sliding_sync_membership_snapshots` table + sliding_sync_insert_map: SlidingSyncStateInsertValues = {} + + # Parse the raw event JSON + for state_key, event in state_map.items(): + if state_key == (EventTypes.Create, ""): + room_type = event.content.get(EventContentFields.ROOM_TYPE) + # Scrutinize JSON values + if room_type is None or ( + isinstance(room_type, str) + # We ignore values with null bytes as Postgres doesn't allow them in + # text columns. 
+ and "\0" not in room_type + ): + sliding_sync_insert_map["room_type"] = room_type + elif state_key == (EventTypes.RoomEncryption, ""): + encryption_algorithm = event.content.get( + EventContentFields.ENCRYPTION_ALGORITHM + ) + is_encrypted = encryption_algorithm is not None + sliding_sync_insert_map["is_encrypted"] = is_encrypted + elif state_key == (EventTypes.Name, ""): + room_name = event.content.get(EventContentFields.ROOM_NAME) + # Scrutinize JSON values. We ignore values with nulls as + # postgres doesn't allow null bytes in text columns. + if room_name is None or ( + isinstance(room_name, str) + # We ignore values with null bytes as Postgres doesn't allow them in + # text columns. + and "\0" not in room_name + ): + sliding_sync_insert_map["room_name"] = room_name + elif state_key == (EventTypes.Tombstone, ""): + successor_room_id = event.content.get( + EventContentFields.TOMBSTONE_SUCCESSOR_ROOM + ) + # Scrutinize JSON values + if successor_room_id is None or ( + isinstance(successor_room_id, str) + # We ignore values with null bytes as Postgres doesn't allow them in + # text columns. + and "\0" not in successor_room_id + ): + sliding_sync_insert_map["tombstone_successor_room_id"] = ( + successor_room_id + ) + else: + # We only expect to see events according to the + # `SLIDING_SYNC_RELEVANT_STATE_SET`. + raise AssertionError( + "Unexpected event (we should not be fetching extra events or this " + + "piece of code needs to be updated to handle a new event type added " + + "to `SLIDING_SYNC_RELEVANT_STATE_SET`): {state_key} {event.event_id}" + ) + + return sliding_sync_insert_map + + @classmethod + def _get_sliding_sync_insert_values_from_stripped_state( + cls, unsigned_stripped_state_events: Any + ) -> SlidingSyncMembershipSnapshotSharedInsertValues: + """ + Pull out the relevant state values from the stripped state on an invite or knock + membership event needed to insert into the `sliding_sync_membership_snapshots` + tables. 
+ + Returns: + Map from column names (`room_type`, `is_encrypted`, `room_name`) to relevant + state values needed to insert into the `sliding_sync_membership_snapshots` tables. + """ + # Map of values to insert/update in the `sliding_sync_membership_snapshots` table + sliding_sync_insert_map: SlidingSyncMembershipSnapshotSharedInsertValues = {} + + if unsigned_stripped_state_events is not None: + stripped_state_map: MutableStateMap[StrippedStateEvent] = {} + if isinstance(unsigned_stripped_state_events, list): + for raw_stripped_event in unsigned_stripped_state_events: + stripped_state_event = parse_stripped_state_event( + raw_stripped_event + ) + if stripped_state_event is not None: + stripped_state_map[ + ( + stripped_state_event.type, + stripped_state_event.state_key, + ) + ] = stripped_state_event + + # If there is some stripped state, we assume the remote server passed *all* + # of the potential stripped state events for the room. + create_stripped_event = stripped_state_map.get((EventTypes.Create, "")) + # Sanity check that we at-least have the create event + if create_stripped_event is not None: + sliding_sync_insert_map["has_known_state"] = True + + # XXX: Keep this up-to-date with `SLIDING_SYNC_RELEVANT_STATE_SET` + + # Find the room_type + sliding_sync_insert_map["room_type"] = ( + create_stripped_event.content.get(EventContentFields.ROOM_TYPE) + if create_stripped_event is not None + else None + ) + + # Find whether the room is_encrypted + encryption_stripped_event = stripped_state_map.get( + (EventTypes.RoomEncryption, "") + ) + encryption = ( + encryption_stripped_event.content.get( + EventContentFields.ENCRYPTION_ALGORITHM + ) + if encryption_stripped_event is not None + else None + ) + sliding_sync_insert_map["is_encrypted"] = encryption is not None + + # Find the room_name + room_name_stripped_event = stripped_state_map.get((EventTypes.Name, "")) + sliding_sync_insert_map["room_name"] = ( + 
room_name_stripped_event.content.get(EventContentFields.ROOM_NAME) + if room_name_stripped_event is not None + else None + ) + + # Check for null bytes in the room name and type. We have to + # ignore values with null bytes as Postgres doesn't allow them + # in text columns. + if ( + sliding_sync_insert_map["room_name"] is not None + and "\0" in sliding_sync_insert_map["room_name"] + ): + sliding_sync_insert_map.pop("room_name") + + if ( + sliding_sync_insert_map["room_type"] is not None + and "\0" in sliding_sync_insert_map["room_type"] + ): + sliding_sync_insert_map.pop("room_type") + + # Find the tombstone_successor_room_id + # Note: This isn't one of the stripped state events according to the spec + # but seems like there is no reason not to support this kind of thing. + tombstone_stripped_event = stripped_state_map.get( + (EventTypes.Tombstone, "") + ) + sliding_sync_insert_map["tombstone_successor_room_id"] = ( + tombstone_stripped_event.content.get( + EventContentFields.TOMBSTONE_SUCCESSOR_ROOM + ) + if tombstone_stripped_event is not None + else None + ) + + if ( + sliding_sync_insert_map["tombstone_successor_room_id"] is not None + and "\0" in sliding_sync_insert_map["tombstone_successor_room_id"] + ): + sliding_sync_insert_map.pop("tombstone_successor_room_id") + + else: + # No stripped state provided + sliding_sync_insert_map["has_known_state"] = False + sliding_sync_insert_map["room_type"] = None + sliding_sync_insert_map["room_name"] = None + sliding_sync_insert_map["is_encrypted"] = False + else: + # No stripped state provided + sliding_sync_insert_map["has_known_state"] = False + sliding_sync_insert_map["room_type"] = None + sliding_sync_insert_map["room_name"] = None + sliding_sync_insert_map["is_encrypted"] = False + + return sliding_sync_insert_map + + def _update_sliding_sync_tables_with_new_persisted_events_txn( + self, + txn: LoggingTransaction, + room_id: str, + events_and_contexts: List[Tuple[EventBase, EventContext]], + ) -> None: + """ + 
Update the latest `event_stream_ordering`/`bump_stamp` columns in the + `sliding_sync_joined_rooms` table for the room with new events. + + This function assumes that `_store_event_txn()` (to persist the event) and + `_update_current_state_txn(...)` (so that `sliding_sync_joined_rooms` table has + been updated with rooms that were joined) have already been run. + + Args: + txn + room_id: The room that all of the events belong to + events_and_contexts: The events being persisted. We assume the list is + sorted ascending by `stream_ordering`. We don't care about the sort when the + events are backfilled (with negative `stream_ordering`). + """ + + # Nothing to do if there are no events + if len(events_and_contexts) == 0: + return + + # Since the list is sorted ascending by `stream_ordering`, the last event should + # have the highest `stream_ordering`. + max_stream_ordering = events_and_contexts[-1][ + 0 + ].internal_metadata.stream_ordering + # `stream_ordering` should be assigned for persisted events + assert max_stream_ordering is not None + # Check if the event is a backfilled event (with a negative `stream_ordering`). + # If one event is backfilled, we assume this whole batch was backfilled. + if max_stream_ordering < 0: + # We only update the sliding sync tables for non-backfilled events. + return + + max_bump_stamp = None + for event, _ in reversed(events_and_contexts): + # Sanity check that all events belong to the same room + assert event.room_id == room_id + + if event.type in SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES: + # `stream_ordering` should be assigned for persisted events + assert event.internal_metadata.stream_ordering is not None + + max_bump_stamp = event.internal_metadata.stream_ordering + + # Since we're iterating in reverse, we can break as soon as we find a + # matching bump event which should have the highest `stream_ordering`. + break + + # Handle updating the `sliding_sync_joined_rooms` table. 
+ # + txn.execute( + """ + UPDATE sliding_sync_joined_rooms + SET + event_stream_ordering = CASE + WHEN event_stream_ordering IS NULL OR event_stream_ordering < ? + THEN ? + ELSE event_stream_ordering + END, + bump_stamp = CASE + WHEN bump_stamp IS NULL OR bump_stamp < ? + THEN ? + ELSE bump_stamp + END + WHERE room_id = ? + """, + ( + max_stream_ordering, + max_stream_ordering, + max_bump_stamp, + max_bump_stamp, + room_id, + ), + ) + # This may or may not update any rows depending if we are `no_longer_in_room` def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None: """Update the room version in the database based off current state @@ -1931,7 +2866,9 @@ class PersistEventsStore: ) for event in events: + # Sanity check that we're working with persisted events assert event.internal_metadata.stream_ordering is not None + assert event.internal_metadata.instance_name is not None # We update the local_current_membership table only if the event is # "current", i.e., its something that has just happened. @@ -1945,6 +2882,16 @@ class PersistEventsStore: and event.internal_metadata.is_outlier() and event.internal_metadata.is_out_of_band_membership() ): + # The only sort of out-of-band-membership events we expect to see here + # are remote invites/knocks and LEAVE events corresponding to + # rejected/retracted invites and rescinded knocks. 
+ assert event.type == EventTypes.Member + assert event.membership in ( + Membership.INVITE, + Membership.KNOCK, + Membership.LEAVE, + ) + self.db_pool.simple_upsert_txn( txn, table="local_current_membership", @@ -1956,6 +2903,59 @@ class PersistEventsStore: }, ) + # Handle updating the `sliding_sync_membership_snapshots` table + # (out-of-band membership events only) + # + raw_stripped_state_events = None + if event.membership == Membership.INVITE: + invite_room_state = event.unsigned.get("invite_room_state") + raw_stripped_state_events = invite_room_state + elif event.membership == Membership.KNOCK: + knock_room_state = event.unsigned.get("knock_room_state") + raw_stripped_state_events = knock_room_state + + insert_values = { + "sender": event.sender, + "membership_event_id": event.event_id, + "membership": event.membership, + # Since this is a new membership, it isn't forgotten anymore (which + # matches how Synapse currently thinks about the forgotten status) + "forgotten": 0, + "event_stream_ordering": event.internal_metadata.stream_ordering, + "event_instance_name": event.internal_metadata.instance_name, + } + if event.membership == Membership.LEAVE: + # Inherit the meta data from the remote invite/knock. When using + # sliding sync filters, this will prevent the room from + # disappearing/appearing just because you left the room. + pass + elif event.membership in (Membership.INVITE, Membership.KNOCK): + extra_insert_values = ( + self._get_sliding_sync_insert_values_from_stripped_state( + raw_stripped_state_events + ) + ) + insert_values.update(extra_insert_values) + else: + # We don't know how to handle this type of membership yet + # + # FIXME: We should use `assert_never` here but for some reason + # the exhaustive matching doesn't recognize the `Never` here. 
+ # assert_never(event.membership) + raise AssertionError( + f"Unexpected out-of-band membership {event.membership} ({event.event_id}) that we don't know how to handle yet" + ) + + self.db_pool.simple_upsert_txn( + txn, + table="sliding_sync_membership_snapshots", + keyvalues={ + "room_id": event.room_id, + "user_id": event.state_key, + }, + values=insert_values, + ) + def _handle_event_relations( self, txn: LoggingTransaction, event: EventBase ) -> None: @@ -2221,7 +3221,7 @@ class PersistEventsStore: if notifiable_events: txn.execute_batch( sql, - ( + [ ( event.room_id, event.internal_metadata.stream_ordering, @@ -2229,18 +3229,18 @@ class PersistEventsStore: event.event_id, ) for event in notifiable_events - ), + ], ) # Now we delete the staging area for *all* events that were being # persisted. txn.execute_batch( "DELETE FROM event_push_actions_staging WHERE event_id = ?", - ( + [ (event.event_id,) for event, _ in all_events_and_contexts if event.internal_metadata.is_notifiable() - ), + ], ) def _remove_push_actions_for_event_id_txn( @@ -2415,7 +3415,7 @@ class PersistEventsStore: ) potential_backwards_extremities.difference_update( - e for e, in existing_events_outliers + e for (e,) in existing_events_outliers ) if potential_backwards_extremities: @@ -2448,8 +3448,7 @@ class PersistEventsStore: # Delete all these events that we've already fetched and now know that their # prev events are the new backwards extremeties. query = ( - "DELETE FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" + "DELETE FROM event_backward_extremities WHERE event_id = ? AND room_id = ?" ) backward_extremity_tuples_to_remove = [ (ev.event_id, ev.room_id) diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 64d303e330..5c83a9f779 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -24,9 +24,14 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, cast import attr -from synapse.api.constants import EventContentFields, RelationTypes +from synapse.api.constants import ( + MAX_DEPTH, + EventContentFields, + Membership, + RelationTypes, +) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.events import make_event_from_dict +from synapse.events import EventBase, make_event_from_dict from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, @@ -34,9 +39,27 @@ from synapse.storage.database import ( LoggingTransaction, make_tuple_comparison_clause, ) -from synapse.storage.databases.main.events import PersistEventsStore +from synapse.storage.databases.main.events import ( + SLIDING_SYNC_RELEVANT_STATE_SET, + PersistEventsStore, + SlidingSyncMembershipInfoWithEventPos, + SlidingSyncMembershipSnapshotSharedInsertValues, + SlidingSyncStateInsertValues, +) +from synapse.storage.databases.main.events_worker import ( + DatabaseCorruptionError, + InvalidEventError, +) +from synapse.storage.databases.main.state_deltas import StateDeltasStore +from synapse.storage.databases.main.stream import StreamWorkerStore +from synapse.storage.engines import PostgresEngine from synapse.storage.types import Cursor -from synapse.types import JsonDict, StrCollection +from synapse.types import JsonDict, RoomStreamToken, StateMap, StrCollection +from synapse.types.handlers import SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES +from synapse.types.state import StateFilter +from synapse.types.storage import _BackgroundUpdates +from synapse.util import json_encoder +from synapse.util.iterutils import batch_iter if TYPE_CHECKING: from synapse.server import HomeServer @@ -59,26 +82,6 @@ _REPLACE_STREAM_ORDERING_SQL_COMMANDS = ( ) -class _BackgroundUpdates: - EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" - EVENT_FIELDS_SENDER_URL_UPDATE_NAME = 
"event_fields_sender_url" - DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" - POPULATE_STREAM_ORDERING2 = "populate_stream_ordering2" - INDEX_STREAM_ORDERING2 = "index_stream_ordering2" - INDEX_STREAM_ORDERING2_CONTAINS_URL = "index_stream_ordering2_contains_url" - INDEX_STREAM_ORDERING2_ROOM_ORDER = "index_stream_ordering2_room_order" - INDEX_STREAM_ORDERING2_ROOM_STREAM = "index_stream_ordering2_room_stream" - INDEX_STREAM_ORDERING2_TS = "index_stream_ordering2_ts" - REPLACE_STREAM_ORDERING_COLUMN = "replace_stream_ordering_column" - - EVENT_EDGES_DROP_INVALID_ROWS = "event_edges_drop_invalid_rows" - EVENT_EDGES_REPLACE_INDEX = "event_edges_replace_index" - - EVENTS_POPULATE_STATE_KEY_REJECTIONS = "events_populate_state_key_rejections" - - EVENTS_JUMP_TO_DATE_INDEX = "events_jump_to_date_index" - - @attr.s(slots=True, frozen=True, auto_attribs=True) class _CalculateChainCover: """Return value for _calculate_chain_cover_txn.""" @@ -97,7 +100,19 @@ class _CalculateChainCover: finished_room_map: Dict[str, Tuple[int, int]] -class EventsBackgroundUpdatesStore(SQLBaseStore): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _JoinedRoomStreamOrderingUpdate: + """ + Intermediate container class used in `SLIDING_SYNC_JOINED_ROOMS_BG_UPDATE` + """ + + # The most recent event stream_ordering for the room + most_recent_event_stream_ordering: int + # The most recent event `bump_stamp` for the room + most_recent_bump_stamp: Optional[int] + + +class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseStore): def __init__( self, database: DatabasePool, @@ -279,6 +294,44 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): where_clause="NOT outlier", ) + # Handle background updates for Sliding Sync tables + # + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.SLIDING_SYNC_PREFILL_JOINED_ROOMS_TO_RECALCULATE_TABLE_BG_UPDATE, + self._sliding_sync_prefill_joined_rooms_to_recalculate_table_bg_update, + ) + 
# Add some background updates to populate the sliding sync tables + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.SLIDING_SYNC_JOINED_ROOMS_BG_UPDATE, + self._sliding_sync_joined_rooms_bg_update, + ) + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BG_UPDATE, + self._sliding_sync_membership_snapshots_bg_update, + ) + # Add a background update to fix data integrity issue in the + # `sliding_sync_membership_snapshots` -> `forgotten` column + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_FIX_FORGOTTEN_COLUMN_BG_UPDATE, + self._sliding_sync_membership_snapshots_fix_forgotten_column_bg_update, + ) + + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.FIXUP_MAX_DEPTH_CAP, self.fixup_max_depth_cap_bg_update + ) + + # We want this to run on the main database at startup before we start processing + # events. 
+ # + # FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the + # foreground update for + # `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by + # https://github.com/element-hq/synapse/issues/17623) + with db_conn.cursor(txn_name="resolve_sliding_sync") as txn: + _resolve_stale_data_in_sliding_sync_tables( + txn=txn, + ) + async def _background_reindex_fields_sender( self, progress: JsonDict, batch_size: int ) -> int: @@ -586,7 +639,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): room_ids = {row[0] for row in rows} for room_id in room_ids: txn.call_after( - self.get_latest_event_ids_in_room.invalidate, (room_id,) # type: ignore[attr-defined] + self.get_latest_event_ids_in_room.invalidate, # type: ignore[attr-defined] + (room_id,), ) self.db_pool.simple_delete_many_txn( @@ -1073,7 +1127,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): PersistEventsStore._add_chain_cover_index( txn, self.db_pool, - self.event_chain_id_gen, # type: ignore[attr-defined] + self.event_chain_id_gen, event_to_room_id, event_to_types, cast(Dict[str, StrCollection], event_to_auth_chain), @@ -1516,3 +1570,1320 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ) return batch_size + + async def _sliding_sync_prefill_joined_rooms_to_recalculate_table_bg_update( + self, progress: JsonDict, _batch_size: int + ) -> int: + """ + Prefill `sliding_sync_joined_rooms_to_recalculate` table with all rooms we know about already. + """ + + def _txn(txn: LoggingTransaction) -> None: + # We do this as one big bulk insert. This has been tested on a bigger + # homeserver with ~10M rooms and took 60s. There is potential for this to + # starve disk usage while this goes on. + # + # We upsert in case we have to run this multiple times. 
+ txn.execute( + """ + INSERT INTO sliding_sync_joined_rooms_to_recalculate + (room_id) + SELECT DISTINCT room_id FROM local_current_membership + WHERE membership = 'join' + ON CONFLICT (room_id) + DO NOTHING; + """, + ) + + await self.db_pool.runInteraction( + "_sliding_sync_prefill_joined_rooms_to_recalculate_table_bg_update", + _txn, + ) + + # Background update is done. + await self.db_pool.updates._end_background_update( + _BackgroundUpdates.SLIDING_SYNC_PREFILL_JOINED_ROOMS_TO_RECALCULATE_TABLE_BG_UPDATE + ) + return 0 + + async def _sliding_sync_joined_rooms_bg_update( + self, progress: JsonDict, batch_size: int + ) -> int: + """ + Background update to populate the `sliding_sync_joined_rooms` table. + """ + # We don't need to fetch any progress state because we just grab the next N + # events in `sliding_sync_joined_rooms_to_recalculate` + + def _get_rooms_to_update_txn(txn: LoggingTransaction) -> List[Tuple[str]]: + """ + Returns: + A list of room ID's to update along with the progress value + (event_stream_ordering) indicating the continuation point in the + `current_state_events` table for the next batch. + """ + # Fetch the set of room IDs that we want to update + # + # We use `current_state_events` table as the barometer for whether the + # server is still participating in the room because if we're + # `no_longer_in_room`, this table would be cleared out for the given + # `room_id`. + txn.execute( + """ + SELECT room_id + FROM sliding_sync_joined_rooms_to_recalculate + LIMIT ? 
+ """, + (batch_size,), + ) + + rooms_to_update_rows = cast(List[Tuple[str]], txn.fetchall()) + + return rooms_to_update_rows + + rooms_to_update = await self.db_pool.runInteraction( + "_sliding_sync_joined_rooms_bg_update._get_rooms_to_update_txn", + _get_rooms_to_update_txn, + ) + + if not rooms_to_update: + await self.db_pool.updates._end_background_update( + _BackgroundUpdates.SLIDING_SYNC_JOINED_ROOMS_BG_UPDATE + ) + return 0 + + # Map from room_id to insert/update state values in the `sliding_sync_joined_rooms` table. + joined_room_updates: Dict[str, SlidingSyncStateInsertValues] = {} + # Map from room_id to stream_ordering/bump_stamp, etc values + joined_room_stream_ordering_updates: Dict[ + str, _JoinedRoomStreamOrderingUpdate + ] = {} + # As long as we get this value before we fetch the current state, we can use it + # to check if something has changed since that point. + most_recent_current_state_delta_stream_id = ( + await self.get_max_stream_id_in_current_state_deltas() + ) + for (room_id,) in rooms_to_update: + current_state_ids_map = await self.db_pool.runInteraction( + "_sliding_sync_joined_rooms_bg_update._get_relevant_sliding_sync_current_state_event_ids_txn", + PersistEventsStore._get_relevant_sliding_sync_current_state_event_ids_txn, + room_id, + ) + + # If we're not joined to the room a) it doesn't belong in the + # `sliding_sync_joined_rooms` table so we should skip and b) we won't have + # any `current_state_events` for the room. + if not current_state_ids_map: + continue + + try: + fetched_events = await self.get_events(current_state_ids_map.values()) + except (DatabaseCorruptionError, InvalidEventError) as e: + logger.warning( + "Failed to fetch state for room '%s' due to corrupted events. Ignoring. 
Error: %s", + room_id, + e, + ) + continue + + current_state_map: StateMap[EventBase] = { + state_key: fetched_events[event_id] + for state_key, event_id in current_state_ids_map.items() + # `get_events(...)` will filter out events for unknown room versions + if event_id in fetched_events + } + + # Even if we are joined to the room, this can happen for unknown room + # versions (old room versions that aren't known anymore) since + # `get_events(...)` will filter out events for unknown room versions + if not current_state_map: + continue + + state_insert_values = ( + PersistEventsStore._get_sliding_sync_insert_values_from_state_map( + current_state_map + ) + ) + # We should have some insert values for each room, even if they are `None` + assert state_insert_values + joined_room_updates[room_id] = state_insert_values + + # Figure out the stream_ordering of the latest event in the room + most_recent_event_pos_results = await self.get_last_event_pos_in_room( + room_id, event_types=None + ) + assert most_recent_event_pos_results is not None, ( + f"We should not be seeing `None` here because the room ({room_id}) should at-least have a create event " + + "given we pulled the room out of `current_state_events`" + ) + most_recent_event_stream_ordering = most_recent_event_pos_results[1].stream + + # The `most_recent_event_stream_ordering` should be positive, + # however there are (very rare) rooms where that is not the case in + # the matrix.org database. It's not clear how they got into that + # state, but does mean that we cannot assert that the stream + # ordering is indeed positive. + + # Figure out the latest `bump_stamp` in the room. This could be `None` for a + # federated room you just joined where all of events are still `outliers` or + # backfilled history. In the Sliding Sync API, we default to the user's + # membership event `stream_ordering` if we don't have a `bump_stamp` so + # having it as `None` in this table is fine. 
+ bump_stamp_event_pos_results = await self.get_last_event_pos_in_room( + room_id, event_types=SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES + ) + most_recent_bump_stamp = None + if ( + bump_stamp_event_pos_results is not None + and bump_stamp_event_pos_results[1].stream > 0 + ): + most_recent_bump_stamp = bump_stamp_event_pos_results[1].stream + + joined_room_stream_ordering_updates[room_id] = ( + _JoinedRoomStreamOrderingUpdate( + most_recent_event_stream_ordering=most_recent_event_stream_ordering, + most_recent_bump_stamp=most_recent_bump_stamp, + ) + ) + + def _fill_table_txn(txn: LoggingTransaction) -> None: + # Handle updating the `sliding_sync_joined_rooms` table + # + for ( + room_id, + update_map, + ) in joined_room_updates.items(): + joined_room_stream_ordering_update = ( + joined_room_stream_ordering_updates[room_id] + ) + event_stream_ordering = ( + joined_room_stream_ordering_update.most_recent_event_stream_ordering + ) + bump_stamp = joined_room_stream_ordering_update.most_recent_bump_stamp + + # Check if the current state has been updated since we gathered it. + # We're being careful not to insert/overwrite with stale data. + state_deltas_since_we_gathered_current_state = ( + self.get_current_state_deltas_for_room_txn( + txn, + room_id, + from_token=RoomStreamToken( + stream=most_recent_current_state_delta_stream_id + ), + to_token=None, + ) + ) + for state_delta in state_deltas_since_we_gathered_current_state: + # We only need to check for the state is relevant to the + # `sliding_sync_joined_rooms` table. + if ( + state_delta.event_type, + state_delta.state_key, + ) in SLIDING_SYNC_RELEVANT_STATE_SET: + # Raising exception so we can just exit and try again. It would + # be hard to resolve this within the transaction because we need + # to get full events out that take redactions into account. We + # could add some retry logic here, but it's easier to just let + # the background update try again. 
+ raise Exception( + "Current state was updated after we gathered it to update " + + "`sliding_sync_joined_rooms` in the background update. " + + "Raising exception so we can just try again." + ) + + # Since we fully insert rows into `sliding_sync_joined_rooms`, we can + # just do everything on insert and `ON CONFLICT DO NOTHING`. + # + self.db_pool.simple_upsert_txn( + txn, + table="sliding_sync_joined_rooms", + keyvalues={"room_id": room_id}, + values={}, + insertion_values={ + **update_map, + # The reason we're only *inserting* (not *updating*) `event_stream_ordering` + # and `bump_stamp` is because if they are present, that means they are already + # up-to-date. + "event_stream_ordering": event_stream_ordering, + "bump_stamp": bump_stamp, + }, + ) + + # Now that we've processed all the room, we can remove them from the + # queue. + # + # Note: we need to remove all the rooms from the queue we pulled out + # from the DB, not just the ones we've processed above. Otherwise + # we'll simply keep pulling out the same rooms over and over again. + self.db_pool.simple_delete_many_batch_txn( + txn, + table="sliding_sync_joined_rooms_to_recalculate", + keys=("room_id",), + values=rooms_to_update, + ) + + await self.db_pool.runInteraction( + "sliding_sync_joined_rooms_bg_update", _fill_table_txn + ) + + return len(rooms_to_update) + + async def _sliding_sync_membership_snapshots_bg_update( + self, progress: JsonDict, batch_size: int + ) -> int: + """ + Background update to populate the `sliding_sync_membership_snapshots` table. + """ + # We do this in two phases: a) the initial phase where we go through all + # room memberships, and then b) a second phase where we look at new + # memberships (this is to handle the case where we downgrade and then + # upgrade again). + # + # We have to do this as two phases (rather than just the second phase + # where we iterate on event_stream_ordering), as the + # `event_stream_ordering` column may have null values for old rows. 
+ # Therefore we first do the set of historic rooms and *then* look at any + # new rows (which will have a non-null `event_stream_ordering`). + initial_phase = progress.get("initial_phase") + if initial_phase is None: + # If this is the first run, store the current max stream position. + # We know we will go through all memberships less than the current + # max in the initial phase. + progress = { + "initial_phase": True, + "last_event_stream_ordering": self.get_room_max_stream_ordering(), + } + await self.db_pool.updates._background_update_progress( + _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BG_UPDATE, + progress, + ) + initial_phase = True + + last_room_id = progress.get("last_room_id", "") + last_user_id = progress.get("last_user_id", "") + last_event_stream_ordering = progress["last_event_stream_ordering"] + + def _find_memberships_to_update_txn( + txn: LoggingTransaction, + ) -> List[ + Tuple[ + str, + Optional[str], + Optional[str], + str, + str, + str, + str, + int, + Optional[str], + bool, + ] + ]: + # Fetch the set of event IDs that we want to update + # + # We skip over rows which we've already handled, i.e. have a + # matching row in `sliding_sync_membership_snapshots` with the same + # room, user and event ID. + # + # We also ignore rooms that the user has left themselves (i.e. not + # kicked). This is to avoid having to port lots of old rooms that we + # will never send down sliding sync (as we exclude such rooms from + # initial syncs). + + if initial_phase: + # There are some old out-of-band memberships (before + # https://github.com/matrix-org/synapse/issues/6983) where we don't have + # the corresponding room stored in the `rooms` table`. We use `LEFT JOIN + # rooms AS r USING (room_id)` to find the rooms missing from `rooms` and + # insert a row for them below. 
+ txn.execute( + """ + SELECT + c.room_id, + r.room_id, + r.room_version, + c.user_id, + e.sender, + c.event_id, + c.membership, + e.stream_ordering, + e.instance_name, + e.outlier + FROM local_current_membership AS c + LEFT JOIN sliding_sync_membership_snapshots AS m USING (room_id, user_id) + INNER JOIN events AS e USING (event_id) + LEFT JOIN rooms AS r ON (c.room_id = r.room_id) + WHERE (c.room_id, c.user_id) > (?, ?) + AND (m.user_id IS NULL OR c.event_id != m.membership_event_id) + ORDER BY c.room_id ASC, c.user_id ASC + LIMIT ? + """, + (last_room_id, last_user_id, batch_size), + ) + elif last_event_stream_ordering is not None: + # It's important to sort by `event_stream_ordering` *ascending* (oldest to + # newest) so that if we see that this background update in progress and want + # to start the catch-up process, we can safely assume that it will + # eventually get to the rooms we want to catch-up on anyway (see + # `_resolve_stale_data_in_sliding_sync_tables()`). + # + # `c.room_id` is duplicated to make it match what we're doing in the + # `initial_phase`. But we can avoid doing the extra `rooms` table join + # because we can assume all of these new events won't have this problem. + txn.execute( + """ + SELECT + c.room_id, + r.room_id, + r.room_version, + c.user_id, + e.sender, + c.event_id, + c.membership, + c.event_stream_ordering, + e.instance_name, + e.outlier + FROM local_current_membership AS c + LEFT JOIN sliding_sync_membership_snapshots AS m USING (room_id, user_id) + INNER JOIN events AS e USING (event_id) + LEFT JOIN rooms AS r ON (c.room_id = r.room_id) + WHERE c.event_stream_ordering > ? + AND (m.user_id IS NULL OR c.event_id != m.membership_event_id) + ORDER BY c.event_stream_ordering ASC + LIMIT ? 
+ """, + (last_event_stream_ordering, batch_size), + ) + else: + raise Exception("last_event_stream_ordering should not be None") + + memberships_to_update_rows = cast( + List[ + Tuple[ + str, + Optional[str], + Optional[str], + str, + str, + str, + str, + int, + Optional[str], + bool, + ] + ], + txn.fetchall(), + ) + + return memberships_to_update_rows + + memberships_to_update_rows = await self.db_pool.runInteraction( + "sliding_sync_membership_snapshots_bg_update._find_memberships_to_update_txn", + _find_memberships_to_update_txn, + ) + + if not memberships_to_update_rows: + if initial_phase: + # Move onto the next phase. + await self.db_pool.updates._background_update_progress( + _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BG_UPDATE, + { + "initial_phase": False, + "last_event_stream_ordering": last_event_stream_ordering, + }, + ) + return 0 + else: + # We've finished both phases, we're done. + await self.db_pool.updates._end_background_update( + _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BG_UPDATE + ) + return 0 + + def _find_previous_invite_or_knock_membership_txn( + txn: LoggingTransaction, room_id: str, user_id: str, event_id: str + ) -> Optional[Tuple[str, str]]: + # Find the previous invite/knock event before the leave event + # + # Here are some notes on how we landed on this query: + # + # We're using `topological_ordering` instead of `stream_ordering` because + # somehow it's possible to have your `leave` event backfilled with a + # negative `stream_ordering` and your previous `invite` event with a + # positive `stream_ordering` so we wouldn't have a chance of finding the + # previous membership with a naive `event_stream_ordering < ?` comparison. + # + # Also be careful because `room_memberships.event_stream_ordering` is + # nullable and not always filled in. You would need to join on `events` to + # rely on `events.stream_ordering` instead. 
Even though the + # `events.stream_ordering` also doesn't have a `NOT NULL` constraint, it + # doesn't have any rows where this is the case (checked on `matrix.org`). + # The fact the `events.stream_ordering` is a nullable column is a holdover + # from a rename of the column. + # + # You might also consider using the `event_auth` table to find the previous + # membership, but there are cases where somehow a membership event doesn't + # point back to the previous membership event in the auth events (unknown + # cause). + txn.execute( + """ + SELECT event_id, membership + FROM room_memberships AS m + INNER JOIN events AS e USING (room_id, event_id) + WHERE + room_id = ? + AND m.user_id = ? + AND (m.membership = ? OR m.membership = ?) + AND e.event_id != ? + ORDER BY e.topological_ordering DESC + LIMIT 1 + """, + ( + room_id, + user_id, + # We look explicitly for `invite` and `knock` events instead of + # just their previous membership as someone could have been `invite` + # -> `ban` -> unbanned (`leave`) and we want to find the `invite` + # event where the stripped state is. + Membership.INVITE, + Membership.KNOCK, + event_id, + ), + ) + row = txn.fetchone() + + if row is None: + # Generally we should have an invite or knock event for leaves + # that are outliers, however this may not always be the case + # (e.g. a local user got kicked but the kick event got pulled in + # as an outlier). + return None + + event_id, membership = row + + return event_id, membership + + # Map from (room_id, user_id) to ... 
+ to_insert_membership_snapshots: Dict[ + Tuple[str, str], SlidingSyncMembershipSnapshotSharedInsertValues + ] = {} + to_insert_membership_infos: Dict[ + Tuple[str, str], SlidingSyncMembershipInfoWithEventPos + ] = {} + for ( + room_id, + room_id_from_rooms_table, + room_version_id, + user_id, + sender, + membership_event_id, + membership, + membership_event_stream_ordering, + membership_event_instance_name, + is_outlier, + ) in memberships_to_update_rows: + # We don't know how to handle `membership` values other than these. The + # code below would need to be updated. + assert membership in ( + Membership.JOIN, + Membership.INVITE, + Membership.KNOCK, + Membership.LEAVE, + Membership.BAN, + ) + + if ( + room_version_id is not None + and room_version_id not in KNOWN_ROOM_VERSIONS + ): + # Ignore rooms with unknown room versions (these were + # experimental rooms, that we no longer support). + continue + + # There are some old out-of-band memberships (before + # https://github.com/matrix-org/synapse/issues/6983) where we don't have the + # corresponding room stored in the `rooms` table`. We have a `FOREIGN KEY` + # constraint on the `sliding_sync_membership_snapshots` table so we have to + # fix-up these memberships by adding the room to the `rooms` table. + if room_id_from_rooms_table is None: + await self.db_pool.simple_insert( + table="rooms", + values={ + "room_id": room_id, + # Only out-of-band memberships are missing from the `rooms` + # table so that is the only type of membership we're dealing + # with here. Since we don't calculate the "chain cover" for + # out-of-band memberships, we can just set this to `True` as if + # the user ever joins the room, we will end up calculating the + # "chain cover" anyway. 
+ "has_auth_chain_index": True, + }, + ) + + # Map of values to insert/update in the `sliding_sync_membership_snapshots` table + sliding_sync_membership_snapshots_insert_map: SlidingSyncMembershipSnapshotSharedInsertValues = {} + if membership == Membership.JOIN: + # If we're still joined, we can pull from current state. + current_state_ids_map: StateMap[ + str + ] = await self.hs.get_storage_controllers().state.get_current_state_ids( + room_id, + state_filter=StateFilter.from_types( + SLIDING_SYNC_RELEVANT_STATE_SET + ), + # Partially-stated rooms should have all state events except for + # remote membership events so we don't need to wait at all because + # we only want some non-membership state + await_full_state=False, + ) + # We're iterating over rooms that we are joined to so they should + # have `current_state_events` and we should have some current state + # for each room + if current_state_ids_map: + try: + fetched_events = await self.get_events( + current_state_ids_map.values() + ) + except (DatabaseCorruptionError, InvalidEventError) as e: + logger.warning( + "Failed to fetch state for room '%s' due to corrupted events. Ignoring. 
Error: %s", + room_id, + e, + ) + continue + + current_state_map: StateMap[EventBase] = { + state_key: fetched_events[event_id] + for state_key, event_id in current_state_ids_map.items() + # `get_events(...)` will filter out events for unknown room versions + if event_id in fetched_events + } + + # Can happen for unknown room versions (old room versions that aren't known + # anymore) since `get_events(...)` will filter out events for unknown room + # versions + if not current_state_map: + continue + + state_insert_values = PersistEventsStore._get_sliding_sync_insert_values_from_state_map( + current_state_map + ) + sliding_sync_membership_snapshots_insert_map.update( + state_insert_values + ) + # We should have some insert values for each room, even if they are `None` + assert sliding_sync_membership_snapshots_insert_map + + # We have current state to work from + sliding_sync_membership_snapshots_insert_map["has_known_state"] = ( + True + ) + else: + # Although we expect every room to have a create event (even + # past unknown room versions since we haven't supported one + # without it), there seem to be some corrupted rooms in + # practice that don't have the create event in the + # `current_state_events` table. The create event does exist + # in the events table though. We'll just say that we don't + # know the state for these rooms and continue on with our + # day. + sliding_sync_membership_snapshots_insert_map = { + "has_known_state": False, + "room_type": None, + "room_name": None, + "is_encrypted": False, + } + elif membership in (Membership.INVITE, Membership.KNOCK) or ( + membership in (Membership.LEAVE, Membership.BAN) and is_outlier + ): + invite_or_knock_event_id = None + invite_or_knock_membership = None + + # If the event is an `out_of_band_membership` (special case of + # `outlier`), we never had historical state so we have to pull from + # the stripped state on the previous invite/knock event. 
This gives + # us a consistent view of the room state regardless of your + # membership (i.e. the room shouldn't disappear if your using the + # `is_encrypted` filter and you leave). + if membership in (Membership.LEAVE, Membership.BAN) and is_outlier: + previous_membership = await self.db_pool.runInteraction( + "sliding_sync_membership_snapshots_bg_update._find_previous_invite_or_knock_membership_txn", + _find_previous_invite_or_knock_membership_txn, + room_id, + user_id, + membership_event_id, + ) + if previous_membership is not None: + ( + invite_or_knock_event_id, + invite_or_knock_membership, + ) = previous_membership + else: + invite_or_knock_event_id = membership_event_id + invite_or_knock_membership = membership + + if ( + invite_or_knock_event_id is not None + and invite_or_knock_membership is not None + ): + # Pull from the stripped state on the invite/knock event + invite_or_knock_event = await self.get_event( + invite_or_knock_event_id + ) + + raw_stripped_state_events = None + if invite_or_knock_membership == Membership.INVITE: + invite_room_state = invite_or_knock_event.unsigned.get( + "invite_room_state" + ) + raw_stripped_state_events = invite_room_state + elif invite_or_knock_membership == Membership.KNOCK: + knock_room_state = invite_or_knock_event.unsigned.get( + "knock_room_state" + ) + raw_stripped_state_events = knock_room_state + + sliding_sync_membership_snapshots_insert_map = PersistEventsStore._get_sliding_sync_insert_values_from_stripped_state( + raw_stripped_state_events + ) + else: + # We couldn't find any state for the membership, so we just have to + # leave it as empty. 
+ sliding_sync_membership_snapshots_insert_map = { + "has_known_state": False, + "room_type": None, + "room_name": None, + "is_encrypted": False, + } + + # We should have some insert values for each room, even if no + # stripped state is on the event because we still want to record + # that we have no known state + assert sliding_sync_membership_snapshots_insert_map + elif membership in (Membership.LEAVE, Membership.BAN): + # Pull from historical state + state_ids_map = await self.hs.get_storage_controllers().state.get_state_ids_for_event( + membership_event_id, + state_filter=StateFilter.from_types( + SLIDING_SYNC_RELEVANT_STATE_SET + ), + # Partially-stated rooms should have all state events except for + # remote membership events so we don't need to wait at all because + # we only want some non-membership state + await_full_state=False, + ) + + try: + fetched_events = await self.get_events(state_ids_map.values()) + except (DatabaseCorruptionError, InvalidEventError) as e: + logger.warning( + "Failed to fetch state for room '%s' due to corrupted events. Ignoring. 
Error: %s", + room_id, + e, + ) + continue + + state_map: StateMap[EventBase] = { + state_key: fetched_events[event_id] + for state_key, event_id in state_ids_map.items() + # `get_events(...)` will filter out events for unknown room versions + if event_id in fetched_events + } + + # Can happen for unknown room versions (old room versions that aren't known + # anymore) since `get_events(...)` will filter out events for unknown room + # versions + if not state_map: + continue + + state_insert_values = ( + PersistEventsStore._get_sliding_sync_insert_values_from_state_map( + state_map + ) + ) + sliding_sync_membership_snapshots_insert_map.update(state_insert_values) + # We should have some insert values for each room, even if they are `None` + assert sliding_sync_membership_snapshots_insert_map + + # We have historical state to work from + sliding_sync_membership_snapshots_insert_map["has_known_state"] = True + else: + # We don't know how to handle this type of membership yet + # + # FIXME: We should use `assert_never` here but for some reason + # the exhaustive matching doesn't recognize the `Never` here. 
+ # assert_never(membership) + raise AssertionError( + f"Unexpected membership {membership} ({membership_event_id}) that we don't know how to handle yet" + ) + + to_insert_membership_snapshots[(room_id, user_id)] = ( + sliding_sync_membership_snapshots_insert_map + ) + to_insert_membership_infos[(room_id, user_id)] = ( + SlidingSyncMembershipInfoWithEventPos( + user_id=user_id, + sender=sender, + membership_event_id=membership_event_id, + membership=membership, + membership_event_stream_ordering=membership_event_stream_ordering, + # If instance_name is null we default to "master" + membership_event_instance_name=membership_event_instance_name + or "master", + ) + ) + + def _fill_table_txn(txn: LoggingTransaction) -> None: + # Handle updating the `sliding_sync_membership_snapshots` table + # + for key, insert_map in to_insert_membership_snapshots.items(): + room_id, user_id = key + membership_info = to_insert_membership_infos[key] + sender = membership_info.sender + membership_event_id = membership_info.membership_event_id + membership = membership_info.membership + membership_event_stream_ordering = ( + membership_info.membership_event_stream_ordering + ) + membership_event_instance_name = ( + membership_info.membership_event_instance_name + ) + + # We don't need to upsert the state because we never partially + # insert/update the snapshots and anything already there is up-to-date + # EXCEPT for the `forgotten` field since that is updated out-of-band + # from the membership changes. 
+ # + # Even though we're only doing insertions, we're using + # `simple_upsert_txn()` here to avoid unique violation errors that would + # happen from `simple_insert_txn()` + self.db_pool.simple_upsert_txn( + txn, + table="sliding_sync_membership_snapshots", + keyvalues={"room_id": room_id, "user_id": user_id}, + values={}, + insertion_values={ + **insert_map, + "sender": sender, + "membership_event_id": membership_event_id, + "membership": membership, + "event_stream_ordering": membership_event_stream_ordering, + "event_instance_name": membership_event_instance_name, + }, + ) + # We need to find the `forgotten` value during the transaction because + # we can't risk inserting stale data. + if isinstance(txn.database_engine, PostgresEngine): + txn.execute( + """ + UPDATE sliding_sync_membership_snapshots + SET + forgotten = m.forgotten + FROM room_memberships AS m + WHERE sliding_sync_membership_snapshots.room_id = ? + AND sliding_sync_membership_snapshots.user_id = ? + AND membership_event_id = ? + AND membership_event_id = m.event_id + AND m.event_id IS NOT NULL + """, + ( + room_id, + user_id, + membership_event_id, + ), + ) + else: + # SQLite doesn't support UPDATE FROM before 3.33.0, so we do + # this via sub-selects. + txn.execute( + """ + UPDATE sliding_sync_membership_snapshots + SET + forgotten = (SELECT forgotten FROM room_memberships WHERE event_id = ?) + WHERE room_id = ? and user_id = ? AND membership_event_id = ? 
+ """, + ( + membership_event_id, + room_id, + user_id, + membership_event_id, + ), + ) + + await self.db_pool.runInteraction( + "sliding_sync_membership_snapshots_bg_update", _fill_table_txn + ) + + # Update the progress + ( + room_id, + _room_id_from_rooms_table, + _room_version_id, + user_id, + _sender, + _membership_event_id, + _membership, + membership_event_stream_ordering, + _membership_event_instance_name, + _is_outlier, + ) = memberships_to_update_rows[-1] + + progress = { + "initial_phase": initial_phase, + "last_room_id": room_id, + "last_user_id": user_id, + "last_event_stream_ordering": last_event_stream_ordering, + } + if not initial_phase: + progress["last_event_stream_ordering"] = membership_event_stream_ordering + + await self.db_pool.updates._background_update_progress( + _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BG_UPDATE, + progress, + ) + + return len(memberships_to_update_rows) + + async def _sliding_sync_membership_snapshots_fix_forgotten_column_bg_update( + self, progress: JsonDict, batch_size: int + ) -> int: + """ + Background update to update the `sliding_sync_membership_snapshots` -> + `forgotten` column to be in sync with the `room_memberships` table. + + Because of previously flawed code (now fixed); any room that someone has + forgotten and subsequently re-joined or had any new membership on, we need to go + and update the column to match the `room_memberships` table as it has fallen out + of sync. + """ + last_event_stream_ordering = progress.get( + "last_event_stream_ordering", -(1 << 31) + ) + + def _txn( + txn: LoggingTransaction, + ) -> int: + """ + Returns: + The number of rows updated. 
+ """ + + # To simplify things, we can just recheck any row in + # `sliding_sync_membership_snapshots` with `forgotten=1` + txn.execute( + """ + SELECT + s.room_id, + s.user_id, + s.membership_event_id, + s.event_stream_ordering, + m.forgotten + FROM sliding_sync_membership_snapshots AS s + INNER JOIN room_memberships AS m ON (s.membership_event_id = m.event_id) + WHERE s.event_stream_ordering > ? + AND s.forgotten = 1 + ORDER BY s.event_stream_ordering ASC + LIMIT ? + """, + (last_event_stream_ordering, batch_size), + ) + + memberships_to_update_rows = cast( + List[Tuple[str, str, str, int, int]], + txn.fetchall(), + ) + if not memberships_to_update_rows: + return 0 + + # Assemble the values to update + # + # (room_id, user_id) + key_values: List[Tuple[str, str]] = [] + # (forgotten,) + value_values: List[Tuple[int]] = [] + for ( + room_id, + user_id, + _membership_event_id, + _event_stream_ordering, + forgotten, + ) in memberships_to_update_rows: + key_values.append( + ( + room_id, + user_id, + ) + ) + value_values.append((forgotten,)) + + # Update all of the rows in one go + self.db_pool.simple_update_many_txn( + txn, + table="sliding_sync_membership_snapshots", + key_names=("room_id", "user_id"), + key_values=key_values, + value_names=("forgotten",), + value_values=value_values, + ) + + # Update the progress + ( + _room_id, + _user_id, + _membership_event_id, + event_stream_ordering, + _forgotten, + ) = memberships_to_update_rows[-1] + self.db_pool.updates._background_update_progress_txn( + txn, + _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_FIX_FORGOTTEN_COLUMN_BG_UPDATE, + { + "last_event_stream_ordering": event_stream_ordering, + }, + ) + + return len(memberships_to_update_rows) + + num_rows = await self.db_pool.runInteraction( + "_sliding_sync_membership_snapshots_fix_forgotten_column_bg_update", + _txn, + ) + + if not num_rows: + await self.db_pool.updates._end_background_update( + 
_BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_FIX_FORGOTTEN_COLUMN_BG_UPDATE + ) + + return num_rows + + async def fixup_max_depth_cap_bg_update( + self, progress: JsonDict, batch_size: int + ) -> int: + """Fixes the topological ordering for events that have a depth greater + than MAX_DEPTH. This should fix /messages ordering oddities.""" + + room_id_bound = progress.get("room_id", "") + + def redo_max_depth_bg_update_txn(txn: LoggingTransaction) -> Tuple[bool, int]: + txn.execute( + """ + SELECT room_id, room_version FROM rooms + WHERE room_id > ? + ORDER BY room_id + LIMIT ? + """, + (room_id_bound, batch_size), + ) + + # Find the next room ID to process, with a relevant room version. + room_ids: List[str] = [] + max_room_id: Optional[str] = None + for room_id, room_version_str in txn: + max_room_id = room_id + + # We only want to process rooms with a known room version that + # has strict canonical json validation enabled. + room_version = KNOWN_ROOM_VERSIONS.get(room_version_str) + if room_version and room_version.strict_canonicaljson: + room_ids.append(room_id) + + if max_room_id is None: + # The query did not return any rooms, so we are done. + return True, 0 + + # Update the progress to the last room ID we pulled from the DB, + # this ensures we always make progress. + self.db_pool.updates._background_update_progress_txn( + txn, + _BackgroundUpdates.FIXUP_MAX_DEPTH_CAP, + progress={"room_id": max_room_id}, + ) + + if not room_ids: + # There were no rooms in this batch that required the fix. + return False, 0 + + clause, list_args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) + sql = f""" + UPDATE events SET topological_ordering = ? + WHERE topological_ordering > ? 
AND {clause} + """ + args = [MAX_DEPTH, MAX_DEPTH] + args.extend(list_args) + txn.execute(sql, args) + + return False, len(room_ids) + + done, num_rooms = await self.db_pool.runInteraction( + "redo_max_depth_bg_update", redo_max_depth_bg_update_txn + ) + + if done: + await self.db_pool.updates._end_background_update( + _BackgroundUpdates.FIXUP_MAX_DEPTH_CAP + ) + + return num_rooms + + +def _resolve_stale_data_in_sliding_sync_tables( + txn: LoggingTransaction, +) -> None: + """ + Clears stale/out-of-date entries from the + `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` tables. + + This accounts for when someone downgrades their Synapse version and then upgrades it + again. This will ensure that we don't have any stale/out-of-date data in the + `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` tables since any new + events sent in rooms would have also needed to be written to the sliding sync + tables. For example a new event needs to bump `event_stream_ordering` in + `sliding_sync_joined_rooms` table or some state in the room changing (like the room + name). Or another example of someone's membership changing in a room affecting + `sliding_sync_membership_snapshots`. + + This way, if a row exists in the sliding sync tables, we are able to rely on it + (accurate data). And if a row doesn't exist, we use a fallback to get the same info + until the background updates fill in the rows or a new event comes in triggering it + to be fully inserted. 
+ + FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the + foreground update for + `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by + https://github.com/element-hq/synapse/issues/17623) + """ + + _resolve_stale_data_in_sliding_sync_joined_rooms_table(txn) + _resolve_stale_data_in_sliding_sync_membership_snapshots_table(txn) + + +def _resolve_stale_data_in_sliding_sync_joined_rooms_table( + txn: LoggingTransaction, +) -> None: + """ + Clears stale/out-of-date entries from the `sliding_sync_joined_rooms` table and + kicks-off the background update to catch-up with what we missed while Synapse was + downgraded. + + See `_resolve_stale_data_in_sliding_sync_tables()` description above for more + context. + """ + + # Find the point when we stopped writing to the `sliding_sync_joined_rooms` table + txn.execute( + """ + SELECT event_stream_ordering + FROM sliding_sync_joined_rooms + ORDER BY event_stream_ordering DESC + LIMIT 1 + """, + ) + + # If we have nothing written to the `sliding_sync_joined_rooms` table, there is + # nothing to clean up + row = cast(Optional[Tuple[int]], txn.fetchone()) + max_stream_ordering_sliding_sync_joined_rooms_table = None + depends_on = None + if row is not None: + (max_stream_ordering_sliding_sync_joined_rooms_table,) = row + + txn.execute( + """ + SELECT room_id + FROM events + WHERE stream_ordering > ? + GROUP BY room_id + ORDER BY MAX(stream_ordering) ASC + """, + (max_stream_ordering_sliding_sync_joined_rooms_table,), + ) + + room_rows = txn.fetchall() + # No new events have been written to the `events` table since the last time we wrote + # to the `sliding_sync_joined_rooms` table so there is nothing to clean up. This is + # the expected normal scenario for people who have not downgraded their Synapse + # version. 
+ if not room_rows: + return + + # 1000 is an arbitrary batch size with no testing + for chunk in batch_iter(room_rows, 1000): + # Handle updating the `sliding_sync_joined_rooms` table + # + # Clear out the stale data + DatabasePool.simple_delete_many_batch_txn( + txn, + table="sliding_sync_joined_rooms", + keys=("room_id",), + values=chunk, + ) + + # Update the `sliding_sync_joined_rooms_to_recalculate` table with the rooms + # that went stale and now need to be recalculated. + DatabasePool.simple_upsert_many_txn_native_upsert( + txn, + table="sliding_sync_joined_rooms_to_recalculate", + key_names=("room_id",), + key_values=chunk, + value_names=(), + # No value columns, therefore make a blank list so that the following + # zip() works correctly. + value_values=[() for x in range(len(chunk))], + ) + else: + # Avoid adding the background updates when there is no data to run them on (if + # the homeserver has no rooms). The portdb script refuses to run with pending + # background updates and since we potentially add them every time the server + # starts, we add this check for to allow the script to breath. + txn.execute("SELECT 1 FROM local_current_membership LIMIT 1") + row = txn.fetchone() + if row is None: + # There are no rooms, so don't schedule the bg update. + return + + # Re-run the `sliding_sync_joined_rooms_to_recalculate` prefill if there is + # nothing in the `sliding_sync_joined_rooms` table + DatabasePool.simple_upsert_txn_native_upsert( + txn, + table="background_updates", + keyvalues={ + "update_name": _BackgroundUpdates.SLIDING_SYNC_PREFILL_JOINED_ROOMS_TO_RECALCULATE_TABLE_BG_UPDATE + }, + values={}, + # Only insert the row if it doesn't already exist. 
If it already exists, + # we're already working on it + insertion_values={ + "progress_json": "{}", + }, + ) + depends_on = _BackgroundUpdates.SLIDING_SYNC_PREFILL_JOINED_ROOMS_TO_RECALCULATE_TABLE_BG_UPDATE + + # Now kick-off the background update to catch-up with what we missed while Synapse + # was downgraded. + # + # We may need to catch-up on everything if we have nothing written to the + # `sliding_sync_joined_rooms` table yet. This could happen if someone had zero rooms + # on their server (so the normal background update completes), downgrade Synapse + # versions, join and create some new rooms, and upgrade again. + DatabasePool.simple_upsert_txn_native_upsert( + txn, + table="background_updates", + keyvalues={ + "update_name": _BackgroundUpdates.SLIDING_SYNC_JOINED_ROOMS_BG_UPDATE + }, + values={}, + # Only insert the row if it doesn't already exist. If it already exists, we will + # eventually fill in the rows we're trying to populate. + insertion_values={ + # Empty progress is expected since it's not used for this background update. + "progress_json": "{}", + # Wait for the prefill to finish + "depends_on": depends_on, + }, + ) + + +def _resolve_stale_data_in_sliding_sync_membership_snapshots_table( + txn: LoggingTransaction, +) -> None: + """ + Clears stale/out-of-date entries from the `sliding_sync_membership_snapshots` table + and kicks-off the background update to catch-up with what we missed while Synapse + was downgraded. + + See `_resolve_stale_data_in_sliding_sync_tables()` description above for more + context. 
+ """ + + # Find the point when we stopped writing to the `sliding_sync_membership_snapshots` table + txn.execute( + """ + SELECT event_stream_ordering + FROM sliding_sync_membership_snapshots + ORDER BY event_stream_ordering DESC + LIMIT 1 + """, + ) + + # If we have nothing written to the `sliding_sync_membership_snapshots` table, + # there is nothing to clean up + row = cast(Optional[Tuple[int]], txn.fetchone()) + max_stream_ordering_sliding_sync_membership_snapshots_table = None + if row is not None: + (max_stream_ordering_sliding_sync_membership_snapshots_table,) = row + + # XXX: Since `forgotten` is simply a flag on the `room_memberships` table that is + # set out-of-band, there is no way to tell whether it was set while Synapse was + # downgraded. The only thing the user can do is `/forget` again if they run into + # this. + # + # This only picks up changes to memberships. + txn.execute( + """ + SELECT user_id, room_id + FROM local_current_membership + WHERE event_stream_ordering > ? + ORDER BY event_stream_ordering ASC + """, + (max_stream_ordering_sliding_sync_membership_snapshots_table,), + ) + + membership_rows = txn.fetchall() + # No new events have been written to the `events` table since the last time we wrote + # to the `sliding_sync_membership_snapshots` table so there is nothing to clean up. + # This is the expected normal scenario for people who have not downgraded their + # Synapse version. + if not membership_rows: + return + + # 1000 is an arbitrary batch size with no testing + for chunk in batch_iter(membership_rows, 1000): + # Handle updating the `sliding_sync_membership_snapshots` table + # + DatabasePool.simple_delete_many_batch_txn( + txn, + table="sliding_sync_membership_snapshots", + keys=("user_id", "room_id"), + values=chunk, + ) + else: + # Avoid adding the background updates when there is no data to run them on (if + # the homeserver has no rooms). 
The portdb script refuses to run with pending + # background updates and since we potentially add them every time the server + # starts, we add this check for to allow the script to breath. + txn.execute("SELECT 1 FROM local_current_membership LIMIT 1") + row = txn.fetchone() + if row is None: + # There are no rooms, so don't schedule the bg update. + return + + # Now kick-off the background update to catch-up with what we missed while Synapse + # was downgraded. + # + # We may need to catch-up on everything if we have nothing written to the + # `sliding_sync_membership_snapshots` table yet. This could happen if someone had + # zero rooms on their server (so the normal background update completes), downgrade + # Synapse versions, join and create some new rooms, and upgrade again. + # + progress_json: JsonDict = {} + if max_stream_ordering_sliding_sync_membership_snapshots_table is not None: + progress_json["initial_phase"] = False + progress_json["last_event_stream_ordering"] = ( + max_stream_ordering_sliding_sync_membership_snapshots_table + ) + + DatabasePool.simple_upsert_txn_native_upsert( + txn, + table="background_updates", + keyvalues={ + "update_name": _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BG_UPDATE + }, + values={}, + # Only insert the row if it doesn't already exist. If it already exists, we will + # eventually fill in the rows we're trying to populate. + insertion_values={ + "progress_json": json_encoder.encode(progress_json), + }, + ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 4d4877c4c3..3db4460f57 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py
@@ -30,6 +30,7 @@ from typing import ( Dict, Iterable, List, + Literal, Mapping, MutableMapping, Optional, @@ -41,7 +42,6 @@ from typing import ( import attr from prometheus_client import Gauge -from typing_extensions import Literal from twisted.internet import defer @@ -61,7 +61,13 @@ from synapse.logging.context import ( current_context, make_deferred_yieldable, ) -from synapse.logging.opentracing import start_active_span, tag_args, trace +from synapse.logging.opentracing import ( + SynapseTags, + set_tag, + start_active_span, + tag_args, + trace, +) from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, @@ -83,6 +89,7 @@ from synapse.storage.util.id_generators import ( from synapse.storage.util.sequence import build_sequence_generator from synapse.types import JsonDict, get_domain_from_id from synapse.types.state import StateFilter +from synapse.types.storage import _BackgroundUpdates from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred, delay_cancellation from synapse.util.caches.descriptors import cached, cachedList @@ -98,6 +105,26 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class DatabaseCorruptionError(RuntimeError): + """We found an event in the DB that has a persisted event ID that doesn't + match its computed event ID.""" + + def __init__( + self, room_id: str, persisted_event_id: str, computed_event_id: str + ) -> None: + self.room_id = room_id + self.persisted_event_id = persisted_event_id + self.computed_event_id = computed_event_id + + message = ( + f"Database corruption: Event {persisted_event_id} in room {room_id} " + f"from the database appears to have been modified (calculated " + f"event id {computed_event_id})" + ) + + super().__init__(message) + + # These values are used in the `enqueue_event` and `_fetch_loop` methods to # control how we batch/bulk fetch events from the database. 
# The values are plucked out of thing air to make initial sync run faster @@ -166,6 +193,14 @@ class _EventRow: outlier: bool +@attr.s(slots=True, frozen=True, auto_attribs=True) +class EventMetadata: + """Event metadata returned by `get_metadata_for_event(..)`""" + + sender: str + received_ts: int + + class EventRedactBehaviour(Enum): """ What to do when retrieving a redacted event from the database. @@ -304,6 +339,16 @@ class EventsWorkerStore(SQLBaseStore): writers=["master"], ) + # Added to accommodate some queries for the admin API in order to fetch/filter + # membership events by when it was received + self.db_pool.updates.register_background_index_update( + update_name="events_received_ts_index", + index_name="received_ts_idx", + table="events", + columns=("received_ts",), + where_clause="type = 'm.room.member'", + ) + def get_un_partial_stated_events_token(self, instance_name: str) -> int: return ( self._un_partial_stated_events_stream_id_gen.get_current_token_for_writer( @@ -457,6 +502,8 @@ class EventsWorkerStore(SQLBaseStore): ) -> Optional[EventBase]: """Get an event from the database by event_id. + Events for unknown room versions will also be filtered out. + Args: event_id: The event_id of the event to fetch @@ -502,6 +549,7 @@ class EventsWorkerStore(SQLBaseStore): return event + @trace async def get_events( self, event_ids: Collection[str], @@ -511,6 +559,10 @@ class EventsWorkerStore(SQLBaseStore): ) -> Dict[str, EventBase]: """Get events from the database + Unknown events will be omitted from the response. + + Events for unknown room versions will also be filtered out. + Args: event_ids: The event_ids of the events to fetch @@ -529,6 +581,11 @@ class EventsWorkerStore(SQLBaseStore): Returns: A mapping from event_id to event. 
""" + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "event_ids.length", + str(len(event_ids)), + ) + events = await self.get_events_as_list( event_ids, redact_behaviour=redact_behaviour, @@ -553,6 +610,8 @@ class EventsWorkerStore(SQLBaseStore): Unknown events will be omitted from the response. + Events for unknown room versions will also be filtered out. + Args: event_ids: The event_ids of the events to fetch @@ -574,6 +633,10 @@ class EventsWorkerStore(SQLBaseStore): Note that the returned list may be smaller than the list of event IDs if not all events could be fetched. """ + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "event_ids.length", + str(len(event_ids)), + ) if not event_ids: return [] @@ -694,10 +757,11 @@ class EventsWorkerStore(SQLBaseStore): return events + @trace @cancellable async def get_unredacted_events_from_cache_or_db( self, - event_ids: Iterable[str], + event_ids: Collection[str], allow_rejected: bool = False, ) -> Dict[str, EventCacheEntry]: """Fetch a bunch of events from the cache or the database. @@ -719,6 +783,11 @@ class EventsWorkerStore(SQLBaseStore): Returns: map from event id to result """ + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "event_ids.length", + str(len(event_ids)), + ) + # Shortcut: check if we have any events in the *in memory* cache - this function # may be called repeatedly for the same event so at this point we cannot reach # out to any external cache for performance reasons. The external cache is @@ -755,9 +824,9 @@ class EventsWorkerStore(SQLBaseStore): if missing_events_ids: - async def get_missing_events_from_cache_or_db() -> ( - Dict[str, EventCacheEntry] - ): + async def get_missing_events_from_cache_or_db() -> Dict[ + str, EventCacheEntry + ]: """Fetches the events in `missing_event_ids` from the database. 
Also creates entries in `self._current_event_fetches` to allow @@ -907,7 +976,7 @@ class EventsWorkerStore(SQLBaseStore): events, update_metrics=update_metrics ) - missing_event_ids = (e for e in events if e not in event_map) + missing_event_ids = [e for e in events if e not in event_map] event_map.update( await self._get_events_from_external_cache( events=missing_event_ids, @@ -917,8 +986,9 @@ class EventsWorkerStore(SQLBaseStore): return event_map + @trace async def _get_events_from_external_cache( - self, events: Iterable[str], update_metrics: bool = True + self, events: Collection[str], update_metrics: bool = True ) -> Dict[str, EventCacheEntry]: """Fetch events from any configured external cache. @@ -928,6 +998,10 @@ class EventsWorkerStore(SQLBaseStore): events: list of event_ids to fetch update_metrics: Whether to update the cache hit ratio metrics """ + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "events.length", + str(len(events)), + ) event_map = {} for event_id in events: @@ -1193,6 +1267,7 @@ class EventsWorkerStore(SQLBaseStore): with PreserveLoggingContext(): self.hs.get_reactor().callFromThread(fire_errback, e) + @trace async def _get_events_from_db( self, event_ids: Collection[str] ) -> Dict[str, EventCacheEntry]: @@ -1211,6 +1286,11 @@ class EventsWorkerStore(SQLBaseStore): map from event id to result. May return extra events which weren't asked for. """ + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "event_ids.length", + str(len(event_ids)), + ) + fetched_event_ids: Set[str] = set() fetched_events: Dict[str, _EventRow] = {} @@ -1356,10 +1436,8 @@ class EventsWorkerStore(SQLBaseStore): if original_ev.event_id != event_id: # it's difficult to see what to do here. Pretty much all bets are off # if Synapse cannot rely on the consistency of its database. 
- raise RuntimeError( - f"Database corruption: Event {event_id} in room {d['room_id']} " - f"from the database appears to have been modified (calculated " - f"event id {original_ev.event_id})" + raise DatabaseCorruptionError( + d["room_id"], event_id, original_ev.event_id ) event_map[event_id] = original_ev @@ -1639,7 +1717,7 @@ class EventsWorkerStore(SQLBaseStore): txn.database_engine, "e.event_id", event_ids ) txn.execute(sql + clause, args) - found_events = {eid for eid, in txn} + found_events = {eid for (eid,) in txn} # ... and then we can update the results for each key return {eid: (eid in found_events) for eid in event_ids} @@ -1838,9 +1916,9 @@ class EventsWorkerStore(SQLBaseStore): " LIMIT ?" ) txn.execute(sql, (-last_id, -current_id, instance_name, limit)) - new_event_updates: List[Tuple[int, Tuple[str, str, str, str, str, str]]] = ( - [] - ) + new_event_updates: List[ + Tuple[int, Tuple[str, str, str, str, str, str]] + ] = [] row: Tuple[int, str, str, str, str, str, str] # Type safety: iterating over `txn` yields `Tuple`, i.e. # `Tuple[Any, ...]` of arbitrary length. 
Mypy detects assigning a @@ -2439,3 +2517,141 @@ class EventsWorkerStore(SQLBaseStore): ) self.invalidate_get_event_cache_after_txn(txn, event_id) + + async def get_events_sent_by_user_in_room( + self, user_id: str, room_id: str, limit: int, filter: Optional[List[str]] = None + ) -> Optional[List[str]]: + """ + Get a list of event ids of events sent by the user in the specified room + + Args: + user_id: user ID to search against + room_id: room ID of the room to search for events in + filter: type of events to filter for + limit: maximum number of event ids to return + """ + + def _get_events_by_user_in_room_txn( + txn: LoggingTransaction, + user_id: str, + room_id: str, + filter: Optional[List[str]], + batch_size: int, + offset: int, + ) -> Tuple[Optional[List[str]], int]: + if filter: + base_clause, args = make_in_list_sql_clause( + txn.database_engine, "type", filter + ) + clause = f"AND {base_clause}" + parameters = (user_id, room_id, *args, batch_size, offset) + else: + clause = "" + parameters = (user_id, room_id, batch_size, offset) + + sql = f""" + SELECT event_id FROM events + WHERE sender = ? AND room_id = ? + {clause} + ORDER BY received_ts DESC + LIMIT ? + OFFSET ? 
+ """ + txn.execute(sql, parameters) + res = txn.fetchall() + if res: + events = [row[0] for row in res] + else: + events = None + + return events, offset + batch_size + + offset = 0 + batch_size = 100 + if batch_size > limit: + batch_size = limit + + selected_ids: List[str] = [] + while offset < limit: + res, offset = await self.db_pool.runInteraction( + "get_events_by_user", + _get_events_by_user_in_room_txn, + user_id, + room_id, + filter, + batch_size, + offset, + ) + if res: + selected_ids = selected_ids + res + else: + break + return selected_ids + + async def have_finished_sliding_sync_background_jobs(self) -> bool: + """Return if it's safe to use the sliding sync membership tables.""" + + return await self.db_pool.updates.have_completed_background_updates( + ( + _BackgroundUpdates.SLIDING_SYNC_PREFILL_JOINED_ROOMS_TO_RECALCULATE_TABLE_BG_UPDATE, + _BackgroundUpdates.SLIDING_SYNC_JOINED_ROOMS_BG_UPDATE, + _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BG_UPDATE, + ) + ) + + async def get_sent_invite_count_by_user(self, user_id: str, from_ts: int) -> int: + """ + Get the number of invites sent by the given user at or after the provided timestamp. + + Args: + user_id: user ID to search against + from_ts: a timestamp in milliseconds from the unix epoch. Filters against + `events.received_ts` + + """ + + def _get_sent_invite_count_by_user_txn( + txn: LoggingTransaction, user_id: str, from_ts: int + ) -> int: + sql = """ + SELECT COUNT(rm.event_id) + FROM room_memberships AS rm + INNER JOIN events AS e USING(event_id) + WHERE rm.sender = ? + AND rm.membership = 'invite' + AND e.type = 'm.room.member' + AND e.received_ts >= ? 
+ """ + + txn.execute(sql, (user_id, from_ts)) + res = txn.fetchone() + + if res is None: + return 0 + return int(res[0]) + + return await self.db_pool.runInteraction( + "_get_sent_invite_count_by_user_txn", + _get_sent_invite_count_by_user_txn, + user_id, + from_ts, + ) + + @cached(tree=True) + async def get_metadata_for_event( + self, room_id: str, event_id: str + ) -> Optional[EventMetadata]: + row = await self.db_pool.simple_select_one( + table="events", + keyvalues={"room_id": room_id, "event_id": event_id}, + retcols=("sender", "received_ts"), + allow_none=True, + desc="get_metadata_for_event", + ) + if row is None: + return None + + return EventMetadata( + sender=row[0], + received_ts=row[1], + ) diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 7617fd3ad4..04866524e3 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py
@@ -19,6 +19,7 @@ # [This file includes modifications made by New Vector Limited] # # +import logging from enum import Enum from typing import ( TYPE_CHECKING, @@ -51,6 +52,8 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = ( "media_repository_drop_index_wo_method_2" ) +logger = logging.getLogger(__name__) + @attr.s(slots=True, frozen=True, auto_attribs=True) class LocalMedia: @@ -65,6 +68,7 @@ class LocalMedia: safe_from_quarantine: bool user_id: Optional[str] authenticated: Optional[bool] + sha256: Optional[str] @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -79,6 +83,7 @@ class RemoteMedia: last_access_ts: int quarantined_by: Optional[str] authenticated: Optional[bool] + sha256: Optional[str] @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -154,6 +159,26 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): unique=True, ) + self.db_pool.updates.register_background_index_update( + update_name="local_media_repository_sha256_idx", + index_name="local_media_repository_sha256", + table="local_media_repository", + where_clause="sha256 IS NOT NULL", + columns=[ + "sha256", + ], + ) + + self.db_pool.updates.register_background_index_update( + update_name="remote_media_cache_sha256_idx", + index_name="remote_media_cache_sha256", + table="remote_media_cache", + where_clause="sha256 IS NOT NULL", + columns=[ + "sha256", + ], + ) + self.db_pool.updates.register_background_update_handler( BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2, self._drop_media_index_without_method, @@ -221,6 +246,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "safe_from_quarantine", "user_id", "authenticated", + "sha256", ), allow_none=True, desc="get_local_media", @@ -239,6 +265,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): safe_from_quarantine=row[7], user_id=row[8], authenticated=row[9], + sha256=row[10], ) async def get_local_media_by_user_paginate( @@ -295,7 +322,8 @@ class 
MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): quarantined_by, safe_from_quarantine, user_id, - authenticated + authenticated, + sha256 FROM local_media_repository WHERE user_id = ? ORDER BY {order_by_column} {order}, media_id ASC @@ -320,6 +348,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): safe_from_quarantine=bool(row[8]), user_id=row[9], authenticated=row[10], + sha256=row[11], ) for row in txn ] @@ -449,6 +478,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): media_length: int, user_id: UserID, url_cache: Optional[str] = None, + sha256: Optional[str] = None, + quarantined_by: Optional[str] = None, ) -> None: if self.hs.config.media.enable_authenticated_media: authenticated = True @@ -466,6 +497,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "user_id": user_id.to_string(), "url_cache": url_cache, "authenticated": authenticated, + "sha256": sha256, + "quarantined_by": quarantined_by, }, desc="store_local_media", ) @@ -477,20 +510,28 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): upload_name: Optional[str], media_length: int, user_id: UserID, + sha256: str, url_cache: Optional[str] = None, + quarantined_by: Optional[str] = None, ) -> None: + updatevalues = { + "media_type": media_type, + "upload_name": upload_name, + "media_length": media_length, + "url_cache": url_cache, + "sha256": sha256, + } + + # This should never be un-set by this function. 
+ if quarantined_by is not None: + updatevalues["quarantined_by"] = quarantined_by + await self.db_pool.simple_update_one( "local_media_repository", keyvalues={ - "user_id": user_id.to_string(), "media_id": media_id, }, - updatevalues={ - "media_type": media_type, - "upload_name": upload_name, - "media_length": media_length, - "url_cache": url_cache, - }, + updatevalues=updatevalues, desc="update_local_media", ) @@ -657,6 +698,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "last_access_ts", "quarantined_by", "authenticated", + "sha256", ), allow_none=True, desc="get_cached_remote_media", @@ -674,6 +716,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): last_access_ts=row[5], quarantined_by=row[6], authenticated=row[7], + sha256=row[8], ) async def store_cached_remote_media( @@ -685,6 +728,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): time_now_ms: int, upload_name: Optional[str], filesystem_id: str, + sha256: Optional[str], ) -> None: if self.hs.config.media.enable_authenticated_media: authenticated = True @@ -703,6 +747,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "filesystem_id": filesystem_id, "last_access_ts": time_now_ms, "authenticated": authenticated, + "sha256": sha256, }, desc="store_cached_remote_media", ) @@ -729,10 +774,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn.execute_batch( sql, - ( + [ (time_ms, media_origin, media_id) for media_origin, media_id in remote_media - ), + ], ) sql = ( @@ -740,7 +785,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): " WHERE media_id = ?" 
) - txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media)) + txn.execute_batch(sql, [(time_ms, media_id) for media_id in local_media]) await self.db_pool.runInteraction( "update_cached_last_access_time", update_cache_txn @@ -946,3 +991,46 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): await self.db_pool.runInteraction( "delete_url_cache_media", _delete_url_cache_media_txn ) + + async def get_is_hash_quarantined(self, sha256: str) -> bool: + """Get whether a specific sha256 hash digest matches any quarantined media. + + Returns: + None if the media_id doesn't exist. + """ + + # If we don't have the index yet, performance tanks, so we return False. + # In the background updates, remote_media_cache_sha256_idx is created + # after local_media_repository_sha256_idx, which is why we only need to + # check for the completion of the former. + if not await self.db_pool.updates.has_completed_background_update( + "remote_media_cache_sha256_idx" + ): + return False + + def get_matching_media_txn( + txn: LoggingTransaction, table: str, sha256: str + ) -> bool: + # Return on first match + sql = """ + SELECT 1 + FROM local_media_repository + WHERE sha256 = ? AND quarantined_by IS NOT NULL + + UNION ALL + + SELECT 1 + FROM remote_media_cache + WHERE sha256 = ? AND quarantined_by IS NOT NULL + LIMIT 1 + """ + txn.execute(sql, (sha256, sha256)) + row = txn.fetchone() + return row is not None + + return await self.db_pool.runInteraction( + "get_matching_media_txn", + get_matching_media_txn, + "local_media_repository", + sha256, + ) diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 8e948c5e8d..c384675839 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -29,7 +29,6 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.registration import RegistrationWorkerStore from synapse.util.caches.descriptors import cached -from synapse.util.threepids import canonicalise_email if TYPE_CHECKING: from synapse.server import HomeServer @@ -65,18 +64,6 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): self._mau_stats_only = hs.config.server.mau_stats_only - if self._update_on_this_worker: - # Do not add more reserved users than the total allowable number - self.db_pool.new_transaction( - db_conn, - "initialise_mau_threepids", - [], - [], - [], - self._initialise_reserved_users, - hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value], - ) - @cached(num_args=0) async def get_monthly_active_count(self) -> int: """Generates current count of monthly active users @@ -174,26 +161,6 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): return await self.db_pool.runInteraction("list_users", _list_users) - async def get_registered_reserved_users(self) -> List[str]: - """Of the reserved threepids defined in config, retrieve those that are associated - with registered users - - Returns: - User IDs of actual users that are reserved - """ - users = [] - - for tp in self.hs.config.server.mau_limits_reserved_threepids[ - : self.hs.config.server.max_mau_value - ]: - user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( - tp["medium"], canonicalise_email(tp["address"]) - ) - if user_id: - users.append(user_id) - - return users - @cached(num_args=1) async def user_last_seen_monthly_active(self, user_id: str) -> Optional[int]: """ @@ -289,50 +256,10 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): ) self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) - reserved_users = await self.get_registered_reserved_users() await self.db_pool.runInteraction( - "reap_monthly_active_users", _reap_users, reserved_users + 
"reap_monthly_active_users", _reap_users, [] ) - def _initialise_reserved_users( - self, txn: LoggingTransaction, threepids: List[dict] - ) -> None: - """Ensures that reserved threepids are accounted for in the MAU table, should - be called on start up. - - Args: - txn: - threepids: List of threepid dicts to reserve - """ - assert ( - self._update_on_this_worker - ), "This worker is not designated to update MAUs" - - # XXX what is this function trying to achieve? It upserts into - # monthly_active_users for each *registered* reserved mau user, but why? - # - # - shouldn't there already be an entry for each reserved user (at least - # if they have been active recently)? - # - # - if it's important that the timestamp is kept up to date, why do we only - # run this at startup? - - for tp in threepids: - user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"]) - - if user_id: - is_support = self.is_support_user_txn(txn, user_id) - if not is_support: - # We do this manually here to avoid hitting https://github.com/matrix-org/synapse/issues/6791 - self.db_pool.simple_upsert_txn( - txn, - table="monthly_active_users", - keyvalues={"user_id": user_id}, - values={"timestamp": int(self._clock.time_msec())}, - ) - else: - logger.warning("mau limit reserved threepid %s not found in db" % tp) - async def upsert_monthly_active_user(self, user_id: str) -> None: """Updates or inserts the user into the monthly active user table, which is used to track the current MAU usage of the server @@ -340,9 +267,9 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): Args: user_id: user to add/update """ - assert ( - self._update_on_this_worker - ), "This worker is not designated to update MAUs" + assert self._update_on_this_worker, ( + "This worker is not designated to update MAUs" + ) # Support user never to be included in MAU stats. 
Note I can't easily call this # from upsert_monthly_active_user_txn because then I need a _txn form of @@ -379,9 +306,9 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): txn: user_id: user to add/update """ - assert ( - self._update_on_this_worker - ), "This worker is not designated to update MAUs" + assert self._update_on_this_worker, ( + "This worker is not designated to update MAUs" + ) # Am consciously deciding to lock the table on the basis that is ought # never be a big table and alternative approaches (batching multiple @@ -409,9 +336,9 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): Args: user_id: the user_id to query """ - assert ( - self._update_on_this_worker - ), "This worker is not designated to update MAUs" + assert self._update_on_this_worker, ( + "This worker is not designated to update MAUs" + ) if self._limit_usage_by_mau or self._mau_stats_only: # Trial users and guests should not be included as part of MAU group diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 996aea808d..30d8a58d96 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py
@@ -18,8 +18,13 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import TYPE_CHECKING, Optional +import json +from typing import TYPE_CHECKING, Dict, Optional, Tuple, cast +from canonicaljson import encode_canonical_json + +from synapse.api.constants import ProfileFields +from synapse.api.errors import Codes, StoreError from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -27,13 +32,17 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.databases.main.roommember import ProfileInfo -from synapse.storage.engines import PostgresEngine -from synapse.types import JsonDict, UserID +from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.types import JsonDict, JsonValue, UserID if TYPE_CHECKING: from synapse.server import HomeServer +# The number of bytes that the serialized profile can have. +MAX_PROFILE_SIZE = 65536 + + class ProfileWorkerStore(SQLBaseStore): def __init__( self, @@ -144,6 +153,16 @@ class ProfileWorkerStore(SQLBaseStore): return 50 async def get_profileinfo(self, user_id: UserID) -> ProfileInfo: + """ + Fetch the display name and avatar URL of a user. + + Args: + user_id: The user ID to fetch the profile for. + + Returns: + The user's display name and avatar URL. Values may be null if unset + or if the user doesn't exist. + """ profile = await self.db_pool.simple_select_one( table="profiles", keyvalues={"full_user_id": user_id.to_string()}, @@ -158,6 +177,15 @@ class ProfileWorkerStore(SQLBaseStore): return ProfileInfo(avatar_url=profile[1], display_name=profile[0]) async def get_profile_displayname(self, user_id: UserID) -> Optional[str]: + """ + Fetch the display name of a user. + + Args: + user_id: The user to get the display name for. + + Raises: + 404 if the user does not exist. 
+ """ return await self.db_pool.simple_select_one_onecol( table="profiles", keyvalues={"full_user_id": user_id.to_string()}, @@ -166,6 +194,15 @@ class ProfileWorkerStore(SQLBaseStore): ) async def get_profile_avatar_url(self, user_id: UserID) -> Optional[str]: + """ + Fetch the avatar URL of a user. + + Args: + user_id: The user to get the avatar URL for. + + Raises: + 404 if the user does not exist. + """ return await self.db_pool.simple_select_one_onecol( table="profiles", keyvalues={"full_user_id": user_id.to_string()}, @@ -173,7 +210,96 @@ class ProfileWorkerStore(SQLBaseStore): desc="get_profile_avatar_url", ) + async def get_profile_field(self, user_id: UserID, field_name: str) -> JsonValue: + """ + Get a custom profile field for a user. + + Args: + user_id: The user's ID. + field_name: The custom profile field name. + + Returns: + The string value if the field exists, otherwise raises 404. + """ + + def get_profile_field(txn: LoggingTransaction) -> JsonValue: + # This will error if field_name has double quotes in it, but that's not + # possible due to the grammar. + field_path = f'$."{field_name}"' + + if isinstance(self.database_engine, PostgresEngine): + sql = """ + SELECT JSONB_PATH_EXISTS(fields, ?), JSONB_EXTRACT_PATH(fields, ?) + FROM profiles + WHERE user_id = ? + """ + txn.execute( + sql, + (field_path, field_name, user_id.localpart), + ) + + # Test exists first since value being None is used for both + # missing and a null JSON value. + exists, value = cast(Tuple[bool, JsonValue], txn.fetchone()) + if not exists: + raise StoreError(404, "No row found") + return value + + else: + sql = """ + SELECT JSON_TYPE(fields, ?), JSON_EXTRACT(fields, ?) + FROM profiles + WHERE user_id = ? + """ + txn.execute( + sql, + (field_path, field_path, user_id.localpart), + ) + + # If value_type is None, then the value did not exist. 
+ value_type, value = cast( + Tuple[Optional[str], JsonValue], txn.fetchone() + ) + if not value_type: + raise StoreError(404, "No row found") + # If value_type is object or array, then need to deserialize the JSON. + # Scalar values are properly returned directly. + if value_type in ("object", "array"): + assert isinstance(value, str) + return json.loads(value) + return value + + return await self.db_pool.runInteraction("get_profile_field", get_profile_field) + + async def get_profile_fields(self, user_id: UserID) -> Dict[str, str]: + """ + Get all custom profile fields for a user. + + Args: + user_id: The user's ID. + + Returns: + A dictionary of custom profile fields. + """ + result = await self.db_pool.simple_select_one_onecol( + table="profiles", + keyvalues={"full_user_id": user_id.to_string()}, + retcol="fields", + desc="get_profile_fields", + ) + # The SQLite driver doesn't automatically convert JSON to + # Python objects + if isinstance(self.database_engine, Sqlite3Engine) and result: + result = json.loads(result) + return result or {} + async def create_profile(self, user_id: UserID) -> None: + """ + Create a blank profile for a user. + + Args: + user_id: The user to create the profile for. + """ user_localpart = user_id.localpart await self.db_pool.simple_insert( table="profiles", @@ -181,6 +307,71 @@ class ProfileWorkerStore(SQLBaseStore): desc="create_profile", ) + def _check_profile_size( + self, + txn: LoggingTransaction, + user_id: UserID, + new_field_name: str, + new_value: JsonValue, + ) -> None: + # For each entry there are 4 quotes (2 each for key and value), 1 colon, + # and 1 comma. + PER_VALUE_EXTRA = 6 + + # Add the size of the current custom profile fields, ignoring the entry + # which will be overwritten. + if isinstance(txn.database_engine, PostgresEngine): + size_sql = """ + SELECT + OCTET_LENGTH((fields - ?)::text), OCTET_LENGTH(displayname), OCTET_LENGTH(avatar_url) + FROM profiles + WHERE + user_id = ? 
+ """ + txn.execute( + size_sql, + (new_field_name, user_id.localpart), + ) + else: + size_sql = """ + SELECT + LENGTH(json_remove(fields, ?)), LENGTH(displayname), LENGTH(avatar_url) + FROM profiles + WHERE + user_id = ? + """ + txn.execute( + size_sql, + # This will error if field_name has double quotes in it, but that's not + # possible due to the grammar. + (f'$."{new_field_name}"', user_id.localpart), + ) + row = cast(Tuple[Optional[int], Optional[int], Optional[int]], txn.fetchone()) + + # The values return null if the column is null. + total_bytes = ( + # Discount the opening and closing braces to avoid double counting, + # but add one for a comma. + # -2 + 1 = -1 + (row[0] - 1 if row[0] else 0) + + ( + row[1] + len("displayname") + PER_VALUE_EXTRA + if new_field_name != ProfileFields.DISPLAYNAME and row[1] + else 0 + ) + + ( + row[2] + len("avatar_url") + PER_VALUE_EXTRA + if new_field_name != ProfileFields.AVATAR_URL and row[2] + else 0 + ) + ) + + # Add the length of the field being added + the braces. + total_bytes += len(encode_canonical_json({new_field_name: new_value})) + + if total_bytes > MAX_PROFILE_SIZE: + raise StoreError(400, "Profile too large", Codes.PROFILE_TOO_LARGE) + async def set_profile_displayname( self, user_id: UserID, new_displayname: Optional[str] ) -> None: @@ -193,14 +384,25 @@ class ProfileWorkerStore(SQLBaseStore): name is removed. 
""" user_localpart = user_id.localpart - await self.db_pool.simple_upsert( - table="profiles", - keyvalues={"user_id": user_localpart}, - values={ - "displayname": new_displayname, - "full_user_id": user_id.to_string(), - }, - desc="set_profile_displayname", + + def set_profile_displayname(txn: LoggingTransaction) -> None: + if new_displayname is not None: + self._check_profile_size( + txn, user_id, ProfileFields.DISPLAYNAME, new_displayname + ) + + self.db_pool.simple_upsert_txn( + txn, + table="profiles", + keyvalues={"user_id": user_localpart}, + values={ + "displayname": new_displayname, + "full_user_id": user_id.to_string(), + }, + ) + + await self.db_pool.runInteraction( + "set_profile_displayname", set_profile_displayname ) async def set_profile_avatar_url( @@ -215,13 +417,125 @@ class ProfileWorkerStore(SQLBaseStore): removed. """ user_localpart = user_id.localpart - await self.db_pool.simple_upsert( - table="profiles", - keyvalues={"user_id": user_localpart}, - values={"avatar_url": new_avatar_url, "full_user_id": user_id.to_string()}, - desc="set_profile_avatar_url", + + def set_profile_avatar_url(txn: LoggingTransaction) -> None: + if new_avatar_url is not None: + self._check_profile_size( + txn, user_id, ProfileFields.AVATAR_URL, new_avatar_url + ) + + self.db_pool.simple_upsert_txn( + txn, + table="profiles", + keyvalues={"user_id": user_localpart}, + values={ + "avatar_url": new_avatar_url, + "full_user_id": user_id.to_string(), + }, + ) + + await self.db_pool.runInteraction( + "set_profile_avatar_url", set_profile_avatar_url ) + async def set_profile_field( + self, user_id: UserID, field_name: str, new_value: JsonValue + ) -> None: + """ + Set a custom profile field for a user. + + Args: + user_id: The user's ID. + field_name: The name of the custom profile field. + new_value: The value of the custom profile field. + """ + + # Encode to canonical JSON. 
+ canonical_value = encode_canonical_json(new_value) + + def set_profile_field(txn: LoggingTransaction) -> None: + self._check_profile_size(txn, user_id, field_name, new_value) + + if isinstance(self.database_engine, PostgresEngine): + from psycopg2.extras import Json + + # Note that the || jsonb operator is not recursive, any duplicate + # keys will be taken from the second value. + sql = """ + INSERT INTO profiles (user_id, full_user_id, fields) VALUES (?, ?, JSON_BUILD_OBJECT(?, ?::jsonb)) + ON CONFLICT (user_id) + DO UPDATE SET full_user_id = EXCLUDED.full_user_id, fields = COALESCE(profiles.fields, '{}'::jsonb) || EXCLUDED.fields + """ + + txn.execute( + sql, + ( + user_id.localpart, + user_id.to_string(), + field_name, + # Pass as a JSON object since we have passing bytes disabled + # at the database driver. + Json(json.loads(canonical_value)), + ), + ) + else: + # You may be tempted to use json_patch instead of providing the parameters + # twice, but that recursively merges objects instead of replacing. + sql = """ + INSERT INTO profiles (user_id, full_user_id, fields) VALUES (?, ?, JSON_OBJECT(?, JSON(?))) + ON CONFLICT (user_id) + DO UPDATE SET full_user_id = EXCLUDED.full_user_id, fields = JSON_SET(COALESCE(profiles.fields, '{}'), ?, JSON(?)) + """ + # This will error if field_name has double quotes in it, but that's not + # possible due to the grammar. + json_field_name = f'$."{field_name}"' + + txn.execute( + sql, + ( + user_id.localpart, + user_id.to_string(), + json_field_name, + canonical_value, + json_field_name, + canonical_value, + ), + ) + + await self.db_pool.runInteraction("set_profile_field", set_profile_field) + + async def delete_profile_field(self, user_id: UserID, field_name: str) -> None: + """ + Remove a custom profile field for a user. + + Args: + user_id: The user's ID. + field_name: The name of the custom profile field. 
+ """ + + def delete_profile_field(txn: LoggingTransaction) -> None: + if isinstance(self.database_engine, PostgresEngine): + sql = """ + UPDATE profiles SET fields = fields - ? + WHERE user_id = ? + """ + txn.execute( + sql, + (field_name, user_id.localpart), + ) + else: + sql = """ + UPDATE profiles SET fields = json_remove(fields, ?) + WHERE user_id = ? + """ + txn.execute( + sql, + # This will error if field_name has double quotes in it. + (f'$."{field_name}"', user_id.localpart), + ) + + await self.db_pool.runInteraction("delete_profile_field", delete_profile_field) + class ProfileStore(ProfileWorkerStore): pass diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 3b81ed943c..a11f522f03 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py
@@ -20,7 +20,7 @@ # import logging -from typing import Any, List, Set, Tuple, cast +from typing import Any, Set, Tuple, cast from synapse.api.errors import SynapseError from synapse.storage.database import LoggingTransaction @@ -199,9 +199,8 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): # Update backward extremeties txn.execute_batch( - "INSERT INTO event_backward_extremities (room_id, event_id)" - " VALUES (?, ?)", - [(room_id, event_id) for event_id, in new_backwards_extrems], + "INSERT INTO event_backward_extremities (room_id, event_id) VALUES (?, ?)", + [(room_id, event_id) for (event_id,) in new_backwards_extrems], ) logger.info("[purge] finding state groups referenced by deleted events") @@ -215,7 +214,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): """ ) - referenced_state_groups = {sg for sg, in txn} + referenced_state_groups = {sg for (sg,) in txn} logger.info( "[purge] found %i referenced state groups", len(referenced_state_groups) ) @@ -332,7 +331,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): return referenced_state_groups - async def purge_room(self, room_id: str) -> List[int]: + async def purge_room(self, room_id: str) -> None: """Deletes all record of a room Args: @@ -348,7 +347,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): # purge any of those rows which were added during the first. 
logger.info("[purge] Starting initial main purge of [1/2]") - state_groups_to_delete = await self.db_pool.runInteraction( + await self.db_pool.runInteraction( "purge_room", self._purge_room_txn, room_id=room_id, @@ -356,18 +355,15 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): ) logger.info("[purge] Starting secondary main purge of [2/2]") - state_groups_to_delete.extend( - await self.db_pool.runInteraction( - "purge_room", - self._purge_room_txn, - room_id=room_id, - ), + await self.db_pool.runInteraction( + "purge_room", + self._purge_room_txn, + room_id=room_id, ) - logger.info("[purge] Done with main purge") - return state_groups_to_delete + logger.info("[purge] Done with main purge") - def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]: + def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> None: # This collides with event persistence so we cannot write new events and metadata into # a room while deleting it or this transaction will fail. if isinstance(self.database_engine, PostgresEngine): @@ -376,18 +372,10 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): (room_id,), ) - # First, fetch all the state groups that should be deleted, before - # we delete that information. - txn.execute( - """ - SELECT DISTINCT state_group FROM events - INNER JOIN event_to_state_groups USING(event_id) - WHERE events.room_id = ? - """, - (room_id,), - ) - - state_groups = [row[0] for row in txn] + if isinstance(self.database_engine, PostgresEngine): + # Disable statement timeouts for this transaction; purging rooms can + # take a while! + txn.execute("SET LOCAL statement_timeout = 0") # Get all the auth chains that are referenced by events that are to be # deleted. @@ -454,6 +442,10 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): # so must be deleted first. 
"local_current_membership", "room_memberships", + # Note: the sliding_sync_ tables have foreign keys to the `events` table + # so must be deleted first. + "sliding_sync_joined_rooms", + "sliding_sync_membership_snapshots", "events", "federation_inbound_events_staging", "receipts_graph", @@ -504,5 +496,3 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): # periodically anyway (https://github.com/matrix-org/synapse/issues/5888) self._invalidate_caches_for_room_and_stream(txn, room_id) - - return state_groups diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index bbdde17711..86c87f78bf 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py
@@ -109,6 +109,7 @@ def _load_rules( msc3664_enabled=experimental_config.msc3664_enabled, msc3381_polls_enabled=experimental_config.msc3381_polls_enabled, msc4028_push_encrypted_events=experimental_config.msc4028_push_encrypted_events, + msc4210_enabled=experimental_config.msc4210_enabled, ) return filtered_rules diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 3bde0ae0d4..9964331510 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py
@@ -30,10 +30,12 @@ from typing import ( Mapping, Optional, Sequence, + Set, Tuple, cast, ) +import attr from immutabledict import immutabledict from synapse.api.constants import EduTypes @@ -43,6 +45,7 @@ from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, + make_tuple_in_list_sql_clause, ) from synapse.storage.engines._base import IsolationLevel from synapse.storage.util.id_generators import MultiWriterIdGenerator @@ -51,10 +54,12 @@ from synapse.types import ( JsonMapping, MultiWriterStreamToken, PersistedPosition, + StrCollection, ) from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.util.iterutils import batch_iter if TYPE_CHECKING: from synapse.server import HomeServer @@ -62,6 +67,57 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +@attr.s(auto_attribs=True, slots=True, frozen=True) +class ReceiptInRoom: + receipt_type: str + user_id: str + event_id: str + thread_id: Optional[str] + data: JsonMapping + + @staticmethod + def merge_to_content(receipts: Collection["ReceiptInRoom"]) -> JsonMapping: + """Merge the given set of receipts (in a room) into the receipt + content format. + + Returns: + A mapping of the combined receipts: event ID -> receipt type -> user + ID -> receipt data. + """ + # MSC4102: always replace threaded receipts with unthreaded ones if + # there is a clash. This means we will drop some receipts, but MSC4102 + # is designed to drop semantically meaningless receipts, so this is + # okay. Previously, we would drop meaningful data! + # + # We do this by finding the unthreaded receipts, and then filtering out + # matching threaded receipts. 
+ + # Set of (user_id, event_id) + unthreaded_receipts: Set[Tuple[str, str]] = { + (receipt.user_id, receipt.event_id) + for receipt in receipts + if receipt.thread_id is None + } + + # event_id -> receipt_type -> user_id -> receipt data + content: Dict[str, Dict[str, Dict[str, JsonMapping]]] = {} + for receipt in receipts: + data = receipt.data + if receipt.thread_id is not None: + if (receipt.user_id, receipt.event_id) in unthreaded_receipts: + # Ignore threaded receipts if we have an unthreaded one. + continue + + data = dict(data) + data["thread_id"] = receipt.thread_id + + content.setdefault(receipt.event_id, {}).setdefault( + receipt.receipt_type, {} + )[receipt.user_id] = data + + return content + + class ReceiptsWorkerStore(SQLBaseStore): def __init__( self, @@ -398,7 +454,7 @@ class ReceiptsWorkerStore(SQLBaseStore): def f( txn: LoggingTransaction, - ) -> List[Tuple[str, str, str, str, Optional[str], str]]: + ) -> Mapping[str, Sequence[ReceiptInRoom]]: if from_key: sql = """ SELECT stream_id, instance_name, room_id, receipt_type, @@ -428,50 +484,46 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql + clause, [to_key.get_max_stream_pos()] + list(args)) - return [ - (room_id, receipt_type, user_id, event_id, thread_id, data) - for stream_id, instance_name, room_id, receipt_type, user_id, event_id, thread_id, data in txn - if MultiWriterStreamToken.is_stream_position_in_range( + results: Dict[str, List[ReceiptInRoom]] = {} + for ( + stream_id, + instance_name, + room_id, + receipt_type, + user_id, + event_id, + thread_id, + data, + ) in txn: + if not MultiWriterStreamToken.is_stream_position_in_range( from_key, to_key, instance_name, stream_id + ): + continue + + results.setdefault(room_id, []).append( + ReceiptInRoom( + receipt_type=receipt_type, + user_id=user_id, + event_id=event_id, + thread_id=thread_id, + data=db_to_json(data), + ) ) - ] + + return results txn_results = await self.db_pool.runInteraction( "_get_linearized_receipts_for_rooms", f 
) - results: JsonDict = {} - for room_id, receipt_type, user_id, event_id, thread_id, data in txn_results: - # We want a single event per room, since we want to batch the - # receipts by room, event and type. - room_event = results.setdefault( - room_id, - {"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}}, - ) - - # The content is of the form: - # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. } - event_entry = room_event["content"].setdefault(event_id, {}) - receipt_type_dict = event_entry.setdefault(receipt_type, {}) - - # MSC4102: always replace threaded receipts with unthreaded ones if there is a clash. - # Specifically: - # - if there is no existing receipt, great, set the data. - # - if there is an existing receipt, is it threaded (thread_id present)? - # YES: replace if this receipt has no thread id. NO: do not replace. - # This means we will drop some receipts, but MSC4102 is designed to drop semantically - # meaningless receipts, so this is okay. Previously, we would drop meaningful data! - receipt_data = db_to_json(data) - if user_id in receipt_type_dict: # existing receipt - # is the existing receipt threaded and we are currently processing an unthreaded one? 
- if "thread_id" in receipt_type_dict[user_id] and not thread_id: - receipt_type_dict[user_id] = ( - receipt_data # replace with unthreaded one - ) - else: # receipt does not exist, just set it - receipt_type_dict[user_id] = receipt_data - if thread_id: - receipt_type_dict[user_id]["thread_id"] = thread_id + results: JsonDict = { + room_id: { + "room_id": room_id, + "type": EduTypes.RECEIPT, + "content": ReceiptInRoom.merge_to_content(receipts), + } + for room_id, receipts in txn_results.items() + } results = { room_id: [results[room_id]] if room_id in results else [] @@ -479,6 +531,69 @@ class ReceiptsWorkerStore(SQLBaseStore): } return results + async def get_linearized_receipts_for_events( + self, + room_and_event_ids: Collection[Tuple[str, str]], + ) -> Mapping[str, Sequence[ReceiptInRoom]]: + """Get all receipts for the given set of events. + + Arguments: + room_and_event_ids: A collection of 2-tuples of room ID and + event IDs to fetch receipts for + + Returns: + A list of receipts, one per room. 
+ """ + if not room_and_event_ids: + return {} + + def get_linearized_receipts_for_events_txn( + txn: LoggingTransaction, + room_id_event_id_tuples: Collection[Tuple[str, str]], + ) -> List[Tuple[str, str, str, str, Optional[str], str]]: + clause, args = make_tuple_in_list_sql_clause( + self.database_engine, ("room_id", "event_id"), room_id_event_id_tuples + ) + + sql = f""" + SELECT room_id, receipt_type, user_id, event_id, thread_id, data + FROM receipts_linearized + WHERE {clause} + """ + + txn.execute(sql, args) + + return txn.fetchall() + + # room_id -> receipts + room_to_receipts: Dict[str, List[ReceiptInRoom]] = {} + for batch in batch_iter(room_and_event_ids, 1000): + batch_results = await self.db_pool.runInteraction( + "get_linearized_receipts_for_events", + get_linearized_receipts_for_events_txn, + batch, + ) + + for ( + room_id, + receipt_type, + user_id, + event_id, + thread_id, + data, + ) in batch_results: + room_to_receipts.setdefault(room_id, []).append( + ReceiptInRoom( + receipt_type=receipt_type, + user_id=user_id, + event_id=event_id, + thread_id=thread_id, + data=db_to_json(data), + ) + ) + + return room_to_receipts + @cached( num_args=2, ) @@ -550,6 +665,114 @@ class ReceiptsWorkerStore(SQLBaseStore): return results + async def get_linearized_receipts_for_user_in_rooms( + self, user_id: str, room_ids: StrCollection, to_key: MultiWriterStreamToken + ) -> Mapping[str, Sequence[ReceiptInRoom]]: + """Fetch all receipts for the user in the given room. + + Returns: + A dict from room ID to receipts in the room. + """ + + def get_linearized_receipts_for_user_in_rooms_txn( + txn: LoggingTransaction, + batch_room_ids: StrCollection, + ) -> List[Tuple[str, str, str, str, Optional[str], str]]: + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", batch_room_ids + ) + + sql = f""" + SELECT instance_name, stream_id, room_id, receipt_type, user_id, event_id, thread_id, data + FROM receipts_linearized + WHERE {clause} AND user_id = ? 
AND stream_id <= ? + """ + + args.append(user_id) + args.append(to_key.get_max_stream_pos()) + + txn.execute(sql, args) + + return [ + (room_id, receipt_type, user_id, event_id, thread_id, data) + for instance_name, stream_id, room_id, receipt_type, user_id, event_id, thread_id, data in txn + if MultiWriterStreamToken.is_stream_position_in_range( + low=None, + high=to_key, + instance_name=instance_name, + pos=stream_id, + ) + ] + + # room_id -> receipts + room_to_receipts: Dict[str, List[ReceiptInRoom]] = {} + for batch in batch_iter(room_ids, 1000): + batch_results = await self.db_pool.runInteraction( + "get_linearized_receipts_for_events", + get_linearized_receipts_for_user_in_rooms_txn, + batch, + ) + + for ( + room_id, + receipt_type, + user_id, + event_id, + thread_id, + data, + ) in batch_results: + room_to_receipts.setdefault(room_id, []).append( + ReceiptInRoom( + receipt_type=receipt_type, + user_id=user_id, + event_id=event_id, + thread_id=thread_id, + data=db_to_json(data), + ) + ) + + return room_to_receipts + + async def get_rooms_with_receipts_between( + self, + room_ids: StrCollection, + from_key: MultiWriterStreamToken, + to_key: MultiWriterStreamToken, + ) -> StrCollection: + """Given a set of room_ids, find out which ones (may) have receipts + between the two tokens (> `from_token` and <= `to_token`).""" + + room_ids = self._receipts_stream_cache.get_entities_changed( + room_ids, from_key.stream + ) + if not room_ids: + return [] + + def f(txn: LoggingTransaction, room_ids: StrCollection) -> StrCollection: + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) + + sql = f""" + SELECT DISTINCT room_id FROM receipts_linearized + WHERE {clause} AND ? < stream_id AND stream_id <= ? 
+ """ + args.append(from_key.stream) + args.append(to_key.get_max_stream_pos()) + + txn.execute(sql, args) + + return [room_id for (room_id,) in txn] + + results: List[str] = [] + for batch in batch_iter(room_ids, 1000): + batch_result = await self.db_pool.runInteraction( + "get_rooms_with_receipts_between", f, batch + ) + results.extend(batch_result) + + return results + async def get_users_sent_receipts_between( self, last_id: int, current_id: int ) -> List[str]: @@ -807,9 +1030,7 @@ class ReceiptsWorkerStore(SQLBaseStore): SELECT event_id WHERE room_id = ? AND stream_ordering IN ( SELECT max(stream_ordering) WHERE %s ) - """ % ( - clause, - ) + """ % (clause,) txn.execute(sql, [room_id] + list(args)) rows = txn.fetchall() @@ -954,6 +1175,12 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore): self.RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME, self._background_receipts_graph_unique_index, ) + self.db_pool.updates.register_background_index_update( + update_name="receipts_room_id_event_id_index", + index_name="receipts_linearized_event_id", + table="receipts_linearized", + columns=("room_id", "event_id"), + ) async def _populate_receipt_event_stream_ordering( self, progress: JsonDict, batch_size: int diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index df7f8a43b7..868803e169 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py
@@ -32,7 +32,6 @@ from synapse.api.errors import ( NotFoundError, StoreError, SynapseError, - ThreepidValidationError, ) from synapse.config.homeserver import HomeServerConfig from synapse.metrics.background_process_metrics import wrap_as_background_process @@ -149,30 +148,6 @@ class LoginTokenLookupResult: """The session ID advertised by the SSO Identity Provider.""" -@attr.s(frozen=True, slots=True, auto_attribs=True) -class ThreepidResult: - medium: str - address: str - validated_at: int - added_at: int - - -@attr.s(frozen=True, slots=True, auto_attribs=True) -class ThreepidValidationSession: - address: str - """address of the 3pid""" - medium: str - """medium of the 3pid""" - client_secret: str - """a secret provided by the client for this validation session""" - session_id: str - """ID of the validation session""" - last_send_attempt: int - """a number serving to dedupe send attempts for this session""" - validated_at: Optional[int] - """timestamp of when this session was validated if so""" - - class RegistrationWorkerStore(CacheInvalidationWorkerStore): def __init__( self, @@ -215,12 +190,6 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): self._set_expiration_date_when_missing, ) - # Create a background job for culling expired 3PID validity tokens - if hs.config.worker.run_background_tasks: - self._clock.looping_call( - self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS - ) - @cached() async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]: """Returns info about the user account, if it exists.""" @@ -583,7 +552,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn) - async def set_user_type(self, user: UserID, user_type: Optional[UserTypes]) -> None: + async def set_user_type( + self, user: UserID, user_type: Optional[Union[UserTypes, str]] + ) -> None: """Sets the user type. 
Args: @@ -683,7 +654,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): retcol="user_type", allow_none=True, ) - return res is None + return res is None or res not in [UserTypes.BOT, UserTypes.SUPPORT] def is_support_user_txn(self, txn: LoggingTransaction, user_id: str) -> bool: res = self.db_pool.simple_select_one_onecol_txn( @@ -759,17 +730,37 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): external_id: id on that system user_id: complete mxid that it is mapped to """ + self._invalidate_cache_and_stream( + txn, self.get_user_by_external_id, (auth_provider, external_id) + ) - self.db_pool.simple_insert_txn( + # This INSERT ... ON CONFLICT DO NOTHING statement will cause a + # 'could not serialize access due to concurrent update' + # if the row is added concurrently by another transaction. + # This is exactly what we want, as it makes the transaction get retried + # in a new snapshot where we can check for a genuine conflict. + was_inserted = self.db_pool.simple_upsert_txn( txn, table="user_external_ids", - values={ - "auth_provider": auth_provider, - "external_id": external_id, - "user_id": user_id, - }, + keyvalues={"auth_provider": auth_provider, "external_id": external_id}, + values={}, + insertion_values={"user_id": user_id}, ) + if not was_inserted: + existing_id = self.db_pool.simple_select_one_onecol_txn( + txn, + table="user_external_ids", + keyvalues={"auth_provider": auth_provider, "user_id": user_id}, + retcol="external_id", + allow_none=True, + ) + + if existing_id != external_id: + raise ExternalIDReuseException( + f"{user_id!r} has external id {existing_id!r} for {auth_provider} but trying to add {external_id!r}" + ) + async def remove_user_external_id( self, auth_provider: str, external_id: str, user_id: str ) -> None: @@ -789,6 +780,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): }, desc="remove_user_external_id", ) + await self.invalidate_cache_and_stream( + "get_user_by_external_id", 
(auth_provider, external_id) + ) async def replace_user_external_id( self, @@ -809,29 +803,20 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): ExternalIDReuseException if the new external_id could not be mapped. """ - def _remove_user_external_ids_txn( + def _replace_user_external_id_txn( txn: LoggingTransaction, - user_id: str, ) -> None: - """Remove all mappings from external user ids to a mxid - If these mappings are not found, this method does nothing. - - Args: - user_id: complete mxid that it is mapped to - """ - self.db_pool.simple_delete_txn( txn, table="user_external_ids", keyvalues={"user_id": user_id}, ) - def _replace_user_external_id_txn( - txn: LoggingTransaction, - ) -> None: - _remove_user_external_ids_txn(txn, user_id) - for auth_provider, external_id in record_external_ids: + self._invalidate_cache_and_stream( + txn, self.get_user_by_external_id, (auth_provider, external_id) + ) + self._record_user_external_id_txn( txn, auth_provider, @@ -847,6 +832,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): except self.database_engine.module.IntegrityError: raise ExternalIDReuseException() + @cached() async def get_user_by_external_id( self, auth_provider: str, external_id: str ) -> Optional[str]: @@ -944,10 +930,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return await self.db_pool.runInteraction("count_users", _count_users) async def count_real_users(self) -> int: - """Counts all users without a special user_type registered on the homeserver.""" + """Counts all users without the bot or support user_types registered on the homeserver.""" def _count_users(txn: LoggingTransaction) -> int: - txn.execute("SELECT COUNT(*) FROM users where user_type is null") + txn.execute( + f"SELECT COUNT(*) FROM users WHERE user_type IS NULL OR user_type NOT IN ('{UserTypes.BOT}', '{UserTypes.SUPPORT}')" + ) row = txn.fetchone() assert row is not None return row[0] @@ -965,161 +953,6 @@ class 
RegistrationWorkerStore(CacheInvalidationWorkerStore): return str(next_id) - async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[str]: - """Returns user id from threepid - - Args: - medium: threepid medium e.g. email - address: threepid address e.g. me@example.com. This must already be - in canonical form. - - Returns: - The user ID or None if no user id/threepid mapping exists - """ - user_id = await self.db_pool.runInteraction( - "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address - ) - return user_id - - def get_user_id_by_threepid_txn( - self, txn: LoggingTransaction, medium: str, address: str - ) -> Optional[str]: - """Returns user id from threepid - - Args: - txn: - medium: threepid medium e.g. email - address: threepid address e.g. me@example.com - - Returns: - user id, or None if no user id/threepid mapping exists - """ - return self.db_pool.simple_select_one_onecol_txn( - txn, - "user_threepids", - {"medium": medium, "address": address}, - "user_id", - True, - ) - - async def user_add_threepid( - self, - user_id: str, - medium: str, - address: str, - validated_at: int, - added_at: int, - ) -> None: - await self.db_pool.simple_upsert( - "user_threepids", - {"medium": medium, "address": address}, - {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, - ) - - async def user_get_threepids(self, user_id: str) -> List[ThreepidResult]: - results = cast( - List[Tuple[str, str, int, int]], - await self.db_pool.simple_select_list( - "user_threepids", - keyvalues={"user_id": user_id}, - retcols=["medium", "address", "validated_at", "added_at"], - desc="user_get_threepids", - ), - ) - return [ - ThreepidResult( - medium=r[0], - address=r[1], - validated_at=r[2], - added_at=r[3], - ) - for r in results - ] - - async def user_delete_threepid( - self, user_id: str, medium: str, address: str - ) -> None: - await self.db_pool.simple_delete( - "user_threepids", - keyvalues={"user_id": user_id, 
"medium": medium, "address": address}, - desc="user_delete_threepid", - ) - - async def add_user_bound_threepid( - self, user_id: str, medium: str, address: str, id_server: str - ) -> None: - """The server proxied a bind request to the given identity server on - behalf of the given user. We need to remember this in case the user - asks us to unbind the threepid. - - Args: - user_id - medium - address - id_server - """ - # We need to use an upsert, in case they user had already bound the - # threepid - await self.db_pool.simple_upsert( - table="user_threepid_id_server", - keyvalues={ - "user_id": user_id, - "medium": medium, - "address": address, - "id_server": id_server, - }, - values={}, - insertion_values={}, - desc="add_user_bound_threepid", - ) - - async def user_get_bound_threepids(self, user_id: str) -> List[Tuple[str, str]]: - """Get the threepids that a user has bound to an identity server through the homeserver - The homeserver remembers where binds to an identity server occurred. Using this - method can retrieve those threepids. - - Args: - user_id: The ID of the user to retrieve threepids for - - Returns: - List of tuples of two strings: - medium: The medium of the threepid (e.g "email") - address: The address of the threepid (e.g "bob@example.com") - """ - return cast( - List[Tuple[str, str]], - await self.db_pool.simple_select_list( - table="user_threepid_id_server", - keyvalues={"user_id": user_id}, - retcols=["medium", "address"], - desc="user_get_bound_threepids", - ), - ) - - async def remove_user_bound_threepid( - self, user_id: str, medium: str, address: str, id_server: str - ) -> None: - """The server proxied an unbind request to the given identity server on - behalf of the given user, so we remove the mapping of threepid to - identity server. 
- - Args: - user_id - medium - address - id_server - """ - await self.db_pool.simple_delete( - table="user_threepid_id_server", - keyvalues={ - "user_id": user_id, - "medium": medium, - "address": address, - "id_server": id_server, - }, - desc="remove_user_bound_threepid", - ) - async def get_id_servers_user_bound( self, user_id: str, medium: str, address: str ) -> List[str]: @@ -1204,123 +1037,6 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return bool(res) - async def get_threepid_validation_session( - self, - medium: Optional[str], - client_secret: str, - address: Optional[str] = None, - sid: Optional[str] = None, - validated: Optional[bool] = True, - ) -> Optional[ThreepidValidationSession]: - """Gets a session_id and last_send_attempt (if available) for a - combination of validation metadata - - Args: - medium: The medium of the 3PID - client_secret: A unique string provided by the client to help identify this - validation attempt - address: The address of the 3PID - sid: The ID of the validation session - validated: Whether sessions should be filtered by - whether they have been validated already or not. None to - perform no filtering - - Returns: - A ThreepidValidationSession or None if a validation session is not found - """ - if not client_secret: - raise SynapseError( - 400, "Missing parameter: client_secret", errcode=Codes.MISSING_PARAM - ) - - keyvalues = {"client_secret": client_secret} - if medium: - keyvalues["medium"] = medium - if address: - keyvalues["address"] = address - if sid: - keyvalues["session_id"] = sid - - assert address or sid - - def get_threepid_validation_session_txn( - txn: LoggingTransaction, - ) -> Optional[ThreepidValidationSession]: - sql = """ - SELECT address, session_id, medium, client_secret, - last_send_attempt, validated_at - FROM threepid_validation_session WHERE %s - """ % ( - " AND ".join("%s = ?" 
% k for k in keyvalues.keys()), - ) - - if validated is not None: - sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL") - - sql += " LIMIT 1" - - txn.execute(sql, list(keyvalues.values())) - row = txn.fetchone() - if not row: - return None - - return ThreepidValidationSession( - address=row[0], - session_id=row[1], - medium=row[2], - client_secret=row[3], - last_send_attempt=row[4], - validated_at=row[5], - ) - - return await self.db_pool.runInteraction( - "get_threepid_validation_session", get_threepid_validation_session_txn - ) - - async def delete_threepid_session(self, session_id: str) -> None: - """Removes a threepid validation session from the database. This can - be done after validation has been performed and whatever action was - waiting on it has been carried out - - Args: - session_id: The ID of the session to delete - """ - - def delete_threepid_session_txn(txn: LoggingTransaction) -> None: - self.db_pool.simple_delete_txn( - txn, - table="threepid_validation_token", - keyvalues={"session_id": session_id}, - ) - self.db_pool.simple_delete_txn( - txn, - table="threepid_validation_session", - keyvalues={"session_id": session_id}, - ) - - await self.db_pool.runInteraction( - "delete_threepid_session", delete_threepid_session_txn - ) - - @wrap_as_background_process("cull_expired_threepid_validation_tokens") - async def cull_expired_threepid_validation_tokens(self) -> None: - """Remove threepid validation tokens with expiry dates that have passed""" - - def cull_expired_threepid_validation_tokens_txn( - txn: LoggingTransaction, ts: int - ) -> None: - sql = """ - DELETE FROM threepid_validation_token WHERE - expires < ? 
- """ - txn.execute(sql, (ts,)) - - await self.db_pool.runInteraction( - "cull_expired_threepid_validation_tokens", - cull_expired_threepid_validation_tokens_txn, - self._clock.time_msec(), - ) - @wrap_as_background_process("account_validity_set_expiration_dates") async def _set_expiration_date_when_missing(self) -> None: """ @@ -1512,15 +1228,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): # Override type because the return type is only optional if # allow_none is True, and we don't want mypy throwing errors # about None not being indexable. - pending, completed = cast( - Tuple[int, int], - self.db_pool.simple_select_one_txn( - txn, - "registration_tokens", - keyvalues={"token": token}, - retcols=["pending", "completed"], - ), + row = self.db_pool.simple_select_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcols=("pending", "completed"), ) + pending = int(row[0]) + completed = int(row[1]) # Decrement pending and increment completed self.db_pool.simple_update_one_txn( @@ -2093,6 +1808,136 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): func=is_user_approved_txn, ) + async def set_user_deactivated_status( + self, user_id: str, deactivated: bool + ) -> None: + """Set the `deactivated` property for the provided user to the provided value. + + Args: + user_id: The ID of the user to set the status for. + deactivated: The value to set for `deactivated`. 
+ """ + + await self.db_pool.runInteraction( + "set_user_deactivated_status", + self.set_user_deactivated_status_txn, + user_id, + deactivated, + ) + + def set_user_deactivated_status_txn( + self, txn: LoggingTransaction, user_id: str, deactivated: bool + ) -> None: + self.db_pool.simple_update_one_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"deactivated": 1 if deactivated else 0}, + ) + self._invalidate_cache_and_stream( + txn, self.get_user_deactivated_status, (user_id,) + ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + self._invalidate_cache_and_stream(txn, self.is_guest, (user_id,)) + + async def set_user_suspended_status(self, user_id: str, suspended: bool) -> None: + """ + Set whether the user's account is suspended in the `users` table. + + Args: + user_id: The user ID of the user in question + suspended: True if the user is suspended, false if not + """ + await self.db_pool.runInteraction( + "set_user_suspended_status", + self.set_user_suspended_status_txn, + user_id, + suspended, + ) + + def set_user_suspended_status_txn( + self, txn: LoggingTransaction, user_id: str, suspended: bool + ) -> None: + self.db_pool.simple_update_one_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"suspended": suspended}, + ) + self._invalidate_cache_and_stream( + txn, self.get_user_suspended_status, (user_id,) + ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + + async def set_user_locked_status(self, user_id: str, locked: bool) -> None: + """Set the `locked` property for the provided user to the provided value. + + Args: + user_id: The ID of the user to set the status for. + locked: The value to set for `locked`. 
+ """ + + await self.db_pool.runInteraction( + "set_user_locked_status", + self.set_user_locked_status_txn, + user_id, + locked, + ) + + def set_user_locked_status_txn( + self, txn: LoggingTransaction, user_id: str, locked: bool + ) -> None: + self.db_pool.simple_update_one_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"locked": locked}, + ) + self._invalidate_cache_and_stream(txn, self.get_user_locked_status, (user_id,)) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + + async def update_user_approval_status( + self, user_id: UserID, approved: bool + ) -> None: + """Set the user's 'approved' flag to the given value. + + The boolean will be turned into an int (in update_user_approval_status_txn) + because the column is a smallint. + + Args: + user_id: the user to update the flag for. + approved: the value to set the flag to. + """ + await self.db_pool.runInteraction( + "update_user_approval_status", + self.update_user_approval_status_txn, + user_id.to_string(), + approved, + ) + + def update_user_approval_status_txn( + self, txn: LoggingTransaction, user_id: str, approved: bool + ) -> None: + """Set the user's 'approved' flag to the given value. + + The boolean is turned into an int because the column is a smallint. + + Args: + txn: the current database transaction. + user_id: the user to update the flag for. + approved: the value to set the flag to. + """ + self.db_pool.simple_update_one_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"approved": approved}, + ) + + # Invalidate the caches of methods that read the value of the 'approved' flag. 
+ self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + self._invalidate_cache_and_stream(txn, self.is_user_approved, (user_id,)) + class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): def __init__( @@ -2205,117 +2050,6 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): return nb_processed - async def set_user_deactivated_status( - self, user_id: str, deactivated: bool - ) -> None: - """Set the `deactivated` property for the provided user to the provided value. - - Args: - user_id: The ID of the user to set the status for. - deactivated: The value to set for `deactivated`. - """ - - await self.db_pool.runInteraction( - "set_user_deactivated_status", - self.set_user_deactivated_status_txn, - user_id, - deactivated, - ) - - def set_user_deactivated_status_txn( - self, txn: LoggingTransaction, user_id: str, deactivated: bool - ) -> None: - self.db_pool.simple_update_one_txn( - txn=txn, - table="users", - keyvalues={"name": user_id}, - updatevalues={"deactivated": 1 if deactivated else 0}, - ) - self._invalidate_cache_and_stream( - txn, self.get_user_deactivated_status, (user_id,) - ) - self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - txn.call_after(self.is_guest.invalidate, (user_id,)) - - async def set_user_suspended_status(self, user_id: str, suspended: bool) -> None: - """ - Set whether the user's account is suspended in the `users` table. 
- - Args: - user_id: The user ID of the user in question - suspended: True if the user is suspended, false if not - """ - await self.db_pool.runInteraction( - "set_user_suspended_status", - self.set_user_suspended_status_txn, - user_id, - suspended, - ) - - def set_user_suspended_status_txn( - self, txn: LoggingTransaction, user_id: str, suspended: bool - ) -> None: - self.db_pool.simple_update_one_txn( - txn=txn, - table="users", - keyvalues={"name": user_id}, - updatevalues={"suspended": suspended}, - ) - self._invalidate_cache_and_stream( - txn, self.get_user_suspended_status, (user_id,) - ) - self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - - async def set_user_locked_status(self, user_id: str, locked: bool) -> None: - """Set the `locked` property for the provided user to the provided value. - - Args: - user_id: The ID of the user to set the status for. - locked: The value to set for `locked`. - """ - - await self.db_pool.runInteraction( - "set_user_locked_status", - self.set_user_locked_status_txn, - user_id, - locked, - ) - - def set_user_locked_status_txn( - self, txn: LoggingTransaction, user_id: str, locked: bool - ) -> None: - self.db_pool.simple_update_one_txn( - txn=txn, - table="users", - keyvalues={"name": user_id}, - updatevalues={"locked": locked}, - ) - self._invalidate_cache_and_stream(txn, self.get_user_locked_status, (user_id,)) - self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - - def update_user_approval_status_txn( - self, txn: LoggingTransaction, user_id: str, approved: bool - ) -> None: - """Set the user's 'approved' flag to the given value. - - The boolean is turned into an int because the column is a smallint. - - Args: - txn: the current database transaction. - user_id: the user to update the flag for. - approved: the value to set the flag to. 
- """ - self.db_pool.simple_update_one_txn( - txn=txn, - table="users", - keyvalues={"name": user_id}, - updatevalues={"approved": approved}, - ) - - # Invalidate the caches of methods that read the value of the 'approved' flag. - self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - self._invalidate_cache_and_stream(txn, self.is_user_approved, (user_id,)) - class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): def __init__( @@ -2326,9 +2060,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): ): super().__init__(database, db_conn, hs) - self._ignore_unknown_session_error = ( - hs.config.server.request_token_inhibit_3pid_errors - ) + self._ignore_unknown_session_error = False # Used to use whether 3pid errors were suppressed or not... Problem? self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") @@ -2514,7 +2246,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): the user, setting their displayname to the given value admin: is an admin user? user_type: type of user. One of the values from api.constants.UserTypes, - or None for a normal user. + a custom value set in the configuration file, or None for a normal + user. shadow_banned: Whether the user is shadow-banned, i.e. they may be told their requests succeeded but we ignore them. 
approved: Whether to consider the user has already been approved by an @@ -2796,96 +2529,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): desc="add_user_pending_deactivation", ) - async def validate_threepid_session( - self, session_id: str, client_secret: str, token: str, current_ts: int - ) -> Optional[str]: - """Attempt to validate a threepid session using a token - - Args: - session_id: The id of a validation session - client_secret: A unique string provided by the client to help identify - this validation attempt - token: A validation token - current_ts: The current unix time in milliseconds. Used for checking - token expiry status - - Raises: - ThreepidValidationError: if a matching validation token was not found or has - expired - - Returns: - A str representing a link to redirect the user to if there is one. - """ - - # Insert everything into a transaction in order to run atomically - def validate_threepid_session_txn(txn: LoggingTransaction) -> Optional[str]: - row = self.db_pool.simple_select_one_txn( - txn, - table="threepid_validation_session", - keyvalues={"session_id": session_id}, - retcols=["client_secret", "validated_at"], - allow_none=True, - ) - - if not row: - if self._ignore_unknown_session_error: - # If we need to inhibit the error caused by an incorrect session ID, - # use None as placeholder values for the client secret and the - # validation timestamp. - # It shouldn't be an issue because they're both only checked after - # the token check, which should fail. And if it doesn't for some - # reason, the next check is on the client secret, which is NOT NULL, - # so we don't have to worry about the client secret matching by - # accident. 
- row = None, None - else: - raise ThreepidValidationError("Unknown session_id") - - retrieved_client_secret, validated_at = row - - row = self.db_pool.simple_select_one_txn( - txn, - table="threepid_validation_token", - keyvalues={"session_id": session_id, "token": token}, - retcols=["expires", "next_link"], - allow_none=True, - ) - - if not row: - raise ThreepidValidationError( - "Validation token not found or has expired" - ) - expires, next_link = row - - if retrieved_client_secret != client_secret: - raise ThreepidValidationError( - "This client_secret does not match the provided session_id" - ) - - # If the session is already validated, no need to revalidate - if validated_at: - return next_link - - if expires <= current_ts: - raise ThreepidValidationError( - "This token has expired. Please request a new one" - ) - - # Looks good. Validate the session - self.db_pool.simple_update_txn( - txn, - table="threepid_validation_session", - keyvalues={"session_id": session_id}, - updatevalues={"validated_at": self._clock.time_msec()}, - ) - - return next_link - - # Return next_link if it exists - return await self.db_pool.runInteraction( - "validate_threepid_session_txn", validate_threepid_session_txn - ) - async def start_or_continue_validation_session( self, medium: str, @@ -2944,25 +2587,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): start_or_continue_validation_session_txn, ) - async def update_user_approval_status( - self, user_id: UserID, approved: bool - ) -> None: - """Set the user's 'approved' flag to the given value. - - The boolean will be turned into an int (in update_user_approval_status_txn) - because the column is a smallint. - - Args: - user_id: the user to update the flag for. - approved: the value to set the flag to. 
- """ - await self.db_pool.runInteraction( - "update_user_approval_status", - self.update_user_approval_status_txn, - user_id.to_string(), - approved, - ) - @wrap_as_background_process("delete_expired_login_tokens") async def _delete_expired_login_tokens(self) -> None: """Remove login tokens with expiry dates that have passed.""" diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 80a4bf95f2..347dbbba6b 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py
@@ -51,11 +51,15 @@ from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.config.homeserver import HomeServerConfig from synapse.events import EventBase from synapse.replication.tcp.streams.partial_state import UnPartialStatedRoomStream -from synapse.storage._base import db_to_json, make_in_list_sql_clause +from synapse.storage._base import ( + db_to_json, + make_in_list_sql_clause, +) from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, + make_tuple_in_list_sql_clause, ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.types import Cursor @@ -73,6 +77,8 @@ logger = logging.getLogger(__name__) @attr.s(slots=True, frozen=True, auto_attribs=True) class RatelimitOverride: + # n.b. elsewhere in Synapse messages_per_second is represented as a float, but it is + # an integer in the database messages_per_second: int burst_count: int @@ -604,6 +610,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): search_term: Optional[str], public_rooms: Optional[bool], empty_rooms: Optional[bool], + emma_include_tombstone: bool = False, ) -> Tuple[List[Dict[str, Any]], int]: """Function to retrieve a paginated list of rooms as json. @@ -623,6 +630,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): If true, empty rooms are queried. if false, empty rooms are excluded from the query. When it is none (the default), both empty rooms and none-empty rooms are queried. + emma_include_tombstone: If true, include tombstone events in the results. 
Returns: A list of room dicts and an integer representing the total number of rooms that exist given this query @@ -791,11 +799,43 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): room_count = cast(Tuple[int], txn.fetchone()) return rooms, room_count[0] - return await self.db_pool.runInteraction( + result = await self.db_pool.runInteraction( "get_rooms_paginate", _get_rooms_paginate_txn, ) + if emma_include_tombstone: + room_id_sql, room_id_args = make_in_list_sql_clause( + self.database_engine, "cse.room_id", [r["room_id"] for r in result[0]] + ) + + tombstone_sql = """ + SELECT cse.room_id, cse.event_id, ej.json + FROM current_state_events cse + JOIN event_json ej USING (event_id) + WHERE cse.type = 'm.room.tombstone' + AND {room_id_sql} + """.format( + room_id_sql=room_id_sql + ) + + def _get_tombstones_txn( + txn: LoggingTransaction, + ) -> Dict[str, JsonDict]: + txn.execute(tombstone_sql, room_id_args) + for room_id, event_id, json in txn: + for result_room in result[0]: + if result_room["room_id"] == room_id: + result_room["gay.rory.synapse_admin_extensions.tombstone"] = db_to_json(json) + break + return result[0], result[1] + + result = await self.db_pool.runInteraction( + "get_rooms_tombstones", _get_tombstones_txn, + ) + + return result + @cached(max_entries=10000) async def get_ratelimit_for_user(self, user_id: str) -> Optional[RatelimitOverride]: """Check if there are any overrides for ratelimiting for the given user @@ -1127,6 +1167,109 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): return local_media_ids + def _quarantine_local_media_txn( + self, + txn: LoggingTransaction, + hashes: Set[str], + media_ids: Set[str], + quarantined_by: Optional[str], + ) -> int: + """Quarantine and unquarantine local media items. 
+ + Args: + txn (cursor) + hashes: A set of sha256 hashes for any media that should be quarantined + media_ids: A set of media IDs for any media that should be quarantined + quarantined_by: The ID of the user who initiated the quarantine request + If it is `None` media will be removed from quarantine + Returns: + The total number of media items quarantined + """ + total_media_quarantined = 0 + + # Effectively a legacy path, update any media that was explicitly named. + if media_ids: + sql_many_clause_sql, sql_many_clause_args = make_in_list_sql_clause( + txn.database_engine, "media_id", media_ids + ) + sql = f""" + UPDATE local_media_repository + SET quarantined_by = ? + WHERE {sql_many_clause_sql}""" + + if quarantined_by is not None: + sql += " AND safe_from_quarantine = FALSE" + + txn.execute(sql, [quarantined_by] + sql_many_clause_args) + # Note that a rowcount of -1 can be used to indicate no rows were affected. + total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0 + + # Update any media that was identified via hash. + if hashes: + sql_many_clause_sql, sql_many_clause_args = make_in_list_sql_clause( + txn.database_engine, "sha256", hashes + ) + sql = f""" + UPDATE local_media_repository + SET quarantined_by = ? 
+ WHERE {sql_many_clause_sql}""" + + if quarantined_by is not None: + sql += " AND safe_from_quarantine = FALSE" + + txn.execute(sql, [quarantined_by] + sql_many_clause_args) + total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0 + + return total_media_quarantined + + def _quarantine_remote_media_txn( + self, + txn: LoggingTransaction, + hashes: Set[str], + media: Set[Tuple[str, str]], + quarantined_by: Optional[str], + ) -> int: + """Quarantine and unquarantine remote items + + Args: + txn (cursor) + hashes: A set of sha256 hashes for any media that should be quarantined + media_ids: A set of tuples (media_origin, media_id) for any media that should be quarantined + quarantined_by: The ID of the user who initiated the quarantine request + If it is `None` media will be removed from quarantine + Returns: + The total number of media items quarantined + """ + total_media_quarantined = 0 + + if media: + sql_in_list_clause, sql_args = make_tuple_in_list_sql_clause( + txn.database_engine, + ("media_origin", "media_id"), + media, + ) + sql = f""" + UPDATE remote_media_cache + SET quarantined_by = ? + WHERE {sql_in_list_clause}""" + + txn.execute(sql, [quarantined_by] + sql_args) + total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0 + + total_media_quarantined = 0 + if hashes: + sql_many_clause_sql, sql_many_clause_args = make_in_list_sql_clause( + txn.database_engine, "sha256", hashes + ) + sql = f""" + UPDATE remote_media_cache + SET quarantined_by = ? 
+ WHERE {sql_many_clause_sql}""" + txn.execute(sql, [quarantined_by] + sql_many_clause_args) + total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0 + + return total_media_quarantined + def _quarantine_media_txn( self, txn: LoggingTransaction, @@ -1146,40 +1289,93 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): Returns: The total number of media items quarantined """ - - # Update all the tables to set the quarantined_by flag - sql = """ - UPDATE local_media_repository - SET quarantined_by = ? - WHERE media_id = ? - """ - - # set quarantine - if quarantined_by is not None: - sql += "AND safe_from_quarantine = FALSE" - txn.executemany( - sql, [(quarantined_by, media_id) for media_id in local_mxcs] + hashes = set() + media_ids = set() + remote_media = set() + + # First, determine the hashes of the media we want to delete. + # We also want the media_ids for any media that lacks a hash. + if local_mxcs: + hash_sql_many_clause_sql, hash_sql_many_clause_args = ( + make_in_list_sql_clause(txn.database_engine, "media_id", local_mxcs) ) - # remove from quarantine - else: - txn.executemany( - sql, [(quarantined_by, media_id) for media_id in local_mxcs] + hash_sql = f"SELECT sha256, media_id FROM local_media_repository WHERE {hash_sql_many_clause_sql}" + if quarantined_by is not None: + hash_sql += " AND safe_from_quarantine = FALSE" + + txn.execute(hash_sql, hash_sql_many_clause_args) + for sha256, media_id in txn: + if sha256: + hashes.add(sha256) + else: + media_ids.add(media_id) + + # Do the same for remote media + if remote_mxcs: + hash_sql_in_list_clause, hash_sql_args = make_tuple_in_list_sql_clause( + txn.database_engine, + ("media_origin", "media_id"), + remote_mxcs, ) - # Note that a rowcount of -1 can be used to indicate no rows were affected. 
- total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0 + hash_sql = f"SELECT sha256, media_origin, media_id FROM remote_media_cache WHERE {hash_sql_in_list_clause}" + txn.execute(hash_sql, hash_sql_args) + for sha256, media_origin, media_id in txn: + if sha256: + hashes.add(sha256) + else: + remote_media.add((media_origin, media_id)) - txn.executemany( - """ - UPDATE remote_media_cache - SET quarantined_by = ? - WHERE media_origin = ? AND media_id = ? - """, - ((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs), + count = self._quarantine_local_media_txn(txn, hashes, media_ids, quarantined_by) + count += self._quarantine_remote_media_txn( + txn, hashes, remote_media, quarantined_by ) - total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0 - return total_media_quarantined + return count + + async def block_room(self, room_id: str, user_id: str) -> None: + """Marks the room as blocked. + + Can be called multiple times (though we'll only track the last user to + block this room). + + Can be called on a room unknown to this homeserver. + + Args: + room_id: Room to block + user_id: Who blocked it + """ + await self.db_pool.simple_upsert( + table="blocked_rooms", + keyvalues={"room_id": room_id}, + values={}, + insertion_values={"user_id": user_id}, + desc="block_room", + ) + await self.db_pool.runInteraction( + "block_room_invalidation", + self._invalidate_cache_and_stream, + self.is_room_blocked, + (room_id,), + ) + + async def unblock_room(self, room_id: str) -> None: + """Remove the room from blocking list. 
+ + Args: + room_id: Room to unblock + """ + await self.db_pool.simple_delete( + table="blocked_rooms", + keyvalues={"room_id": room_id}, + desc="unblock_room", + ) + await self.db_pool.runInteraction( + "block_room_invalidation", + self._invalidate_cache_and_stream, + self.is_room_blocked, + (room_id,), + ) async def get_rooms_for_retention_period_in_range( self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False @@ -1382,6 +1578,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): partial_state_rooms = {row[0] for row in rows} return {room_id: room_id in partial_state_rooms for room_id in room_ids} + @cached(max_entries=10000, iterable=True) + async def get_partial_rooms(self) -> AbstractSet[str]: + """Get any "partial-state" rooms which the user is in. + + This is fast as the set of partially stated rooms at any point across + the whole server is small, and so such a query is fast. This is also + faster than looking up whether a set of room ID's are partially stated + via `is_partial_state_room_batched(...)` because of the sheer amount of + CPU time looking all the rooms up in the cache. 
+ """ + + def _get_partial_rooms_for_user_txn( + txn: LoggingTransaction, + ) -> AbstractSet[str]: + sql = """ + SELECT room_id FROM partial_state_rooms + """ + txn.execute(sql) + return {room_id for (room_id,) in txn} + + return await self.db_pool.runInteraction( + "get_partial_rooms_for_user", _get_partial_rooms_for_user_txn + ) + async def get_join_event_id_and_device_lists_stream_id_for_partial_state( self, room_id: str ) -> Tuple[str, int]: @@ -1562,6 +1782,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): direction: Direction = Direction.BACKWARDS, user_id: Optional[str] = None, room_id: Optional[str] = None, + event_sender_user_id: Optional[str] = None, ) -> Tuple[List[Dict[str, Any]], int]: """Retrieve a paginated list of event reports @@ -1572,6 +1793,8 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): oldest first (forwards) user_id: search for user_id. Ignored if user_id is None room_id: search for room_id. Ignored if room_id is None + event_sender_user_id: search for the sender of the reported event. 
Ignored if + event_sender_user_id is None Returns: Tuple of: json list of event reports @@ -1591,6 +1814,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): filters.append("er.room_id LIKE ?") args.extend(["%" + room_id + "%"]) + if event_sender_user_id: + filters.append("events.sender = ?") + args.extend([event_sender_user_id]) + if direction == Direction.BACKWARDS: order = "DESC" else: @@ -1606,11 +1833,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): sql = """ SELECT COUNT(*) as total_event_reports FROM event_reports AS er + LEFT JOIN events USING(event_id) JOIN room_stats_state ON room_stats_state.room_id = er.room_id {} - """.format( - where_clause - ) + """.format(where_clause) txn.execute(sql, args) count = cast(Tuple[int], txn.fetchone())[0] @@ -1626,8 +1852,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): room_stats_state.canonical_alias, room_stats_state.name FROM event_reports AS er - LEFT JOIN events - ON events.event_id = er.event_id + LEFT JOIN events USING(event_id) JOIN room_stats_state ON room_stats_state.room_id = er.room_id {where_clause} @@ -2343,6 +2568,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): self._invalidate_cache_and_stream( txn, self._get_partial_state_servers_at_join, (room_id,) ) + self._invalidate_all_cache_and_stream(txn, self.get_partial_rooms) async def write_partial_state_rooms_join_event_id( self, @@ -2470,50 +2696,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): ) return next_id - async def block_room(self, room_id: str, user_id: str) -> None: - """Marks the room as blocked. - - Can be called multiple times (though we'll only track the last user to - block this room). - - Can be called on a room unknown to this homeserver. 
- - Args: - room_id: Room to block - user_id: Who blocked it - """ - await self.db_pool.simple_upsert( - table="blocked_rooms", - keyvalues={"room_id": room_id}, - values={}, - insertion_values={"user_id": user_id}, - desc="block_room", - ) - await self.db_pool.runInteraction( - "block_room_invalidation", - self._invalidate_cache_and_stream, - self.is_room_blocked, - (room_id,), - ) - - async def unblock_room(self, room_id: str) -> None: - """Remove the room from blocking list. - - Args: - room_id: Room to unblock - """ - await self.db_pool.simple_delete( - table="blocked_rooms", - keyvalues={"room_id": room_id}, - desc="unblock_room", - ) - await self.db_pool.runInteraction( - "block_room_invalidation", - self._invalidate_cache_and_stream, - self.is_room_blocked, - (room_id,), - ) - async def clear_partial_state_room(self, room_id: str) -> Optional[int]: """Clears the partial state flag for a room. @@ -2527,7 +2709,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): still contains events with partial state. """ try: - async with self._un_partial_stated_rooms_stream_id_gen.get_next() as un_partial_state_room_stream_id: + async with ( + self._un_partial_stated_rooms_stream_id_gen.get_next() as un_partial_state_room_stream_id + ): await self.db_pool.runInteraction( "clear_partial_state_room", self._clear_partial_state_room_txn, @@ -2564,6 +2748,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): self._invalidate_cache_and_stream( txn, self._get_partial_state_servers_at_join, (room_id,) ) + self._invalidate_all_cache_and_stream(txn, self.get_partial_rooms) DatabasePool.simple_insert_txn( txn, diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 1d9f0f52e1..7ca73abb83 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py
@@ -19,6 +19,7 @@ # # import logging +from http import HTTPStatus from typing import ( TYPE_CHECKING, AbstractSet, @@ -39,6 +40,8 @@ from typing import ( import attr from synapse.api.constants import EventTypes, Membership +from synapse.api.errors import Codes, SynapseError +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.logging.opentracing import trace from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import wrap_as_background_process @@ -50,13 +53,20 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.stream import _filter_results_by_stream from synapse.storage.engines import Sqlite3Engine -from synapse.storage.roommember import MemberSummary, ProfileInfo, RoomsForUser +from synapse.storage.roommember import ( + MemberSummary, + ProfileInfo, + RoomsForUser, + RoomsForUserSlidingSync, +) from synapse.types import ( JsonDict, PersistedEventPosition, StateMap, StrCollection, + StreamToken, get_domain_from_id, ) from synapse.util.caches.descriptors import _CacheContext, cached, cachedList @@ -71,6 +81,7 @@ logger = logging.getLogger(__name__) _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" +_POPULATE_PARTICIPANT_BG_UPDATE_BATCH_SIZE = 1000 @attr.s(frozen=True, slots=True, auto_attribs=True) @@ -225,9 +236,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): AND m.room_id = c.room_id AND m.user_id = c.state_key WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? 
AND %s - """ % ( - clause, - ) + """ % (clause,) txn.execute(sql, (room_id, Membership.JOIN, *ids)) return {r[0]: ProfileInfo(display_name=r[1], avatar_url=r[2]) for r in txn} @@ -306,18 +315,10 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): # We do this all in one transaction to keep the cache small. # FIXME: get rid of this when we have room_stats - # Note, rejected events will have a null membership field, so - # we we manually filter them out. - sql = """ - SELECT count(*), membership FROM current_state_events - WHERE type = 'm.room.member' AND room_id = ? - AND membership IS NOT NULL - GROUP BY membership - """ + counts = self._get_member_counts_txn(txn, room_id) - txn.execute(sql, (room_id,)) res: Dict[str, MemberSummary] = {} - for count, membership in txn: + for membership, count in counts.items(): res.setdefault(membership, MemberSummary([], count)) # Order by membership (joins -> invites -> leave (former insiders) -> @@ -364,6 +365,31 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): ) @cached() + async def get_member_counts(self, room_id: str) -> Mapping[str, int]: + """Get a mapping of number of users by membership""" + + return await self.db_pool.runInteraction( + "get_member_counts", self._get_member_counts_txn, room_id + ) + + def _get_member_counts_txn( + self, txn: LoggingTransaction, room_id: str + ) -> Dict[str, int]: + """Get a mapping of number of users by membership""" + + # Note, rejected events will have a null membership field, so + # we we manually filter them out. + sql = """ + SELECT count(*), membership FROM current_state_events + WHERE type = 'm.room.member' AND room_id = ? 
+ AND membership IS NOT NULL + GROUP BY membership + """ + + txn.execute(sql, (room_id,)) + return {membership: count for count, membership in txn} + + @cached() async def get_number_joined_users_in_room(self, room_id: str) -> int: return await self.db_pool.simple_select_one_onecol( table="current_state_events", @@ -524,9 +550,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): WHERE user_id = ? AND %s - """ % ( - clause, - ) + """ % (clause,) txn.execute(sql, (user_id, *args)) results = [ @@ -631,10 +655,8 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): """ # Paranoia check. if not self.hs.is_mine_id(user_id): - raise Exception( - "Cannot call 'get_local_current_membership_for_user_in_room' on " - "non-local user %s" % (user_id,), - ) + message = f"Provided user_id {user_id} is a non-local user" + raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.BAD_JSON) results = cast( Optional[Tuple[str, str]], @@ -692,6 +714,27 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): return {row[0] for row in txn} + async def get_rooms_user_currently_banned_from( + self, user_id: str + ) -> FrozenSet[str]: + """Returns a set of room_ids the user is currently banned from. + + If a remote user only returns rooms this server is currently + participating in. + """ + room_ids = await self.db_pool.simple_select_onecol( + table="current_state_events", + keyvalues={ + "type": EventTypes.Member, + "membership": Membership.BAN, + "state_key": user_id, + }, + retcol="room_id", + desc="get_rooms_user_currently_banned_from", + ) + + return frozenset(room_ids) + @cached(max_entries=500000, iterable=True) async def get_rooms_for_user(self, user_id: str) -> FrozenSet[str]: """Returns a set of room_ids the user is currently joined to. 
@@ -808,7 +851,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): """ txn.execute(sql, (user_id, *args)) - return {u: True for u, in txn} + return {u: True for (u,) in txn} to_return = {} for batch_user_ids in batch_iter(other_user_ids, 1000): @@ -828,6 +871,73 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): return {u for u, share_room in user_dict.items() if share_room} + @cached(max_entries=10000) + async def does_pair_of_users_share_a_room_joined_or_invited( + self, user_id: str, other_user_id: str + ) -> bool: + raise NotImplementedError() + + @cachedList( + cached_method_name="does_pair_of_users_share_a_room_joined_or_invited", + list_name="other_user_ids", + ) + async def _do_users_share_a_room_joined_or_invited( + self, user_id: str, other_user_ids: Collection[str] + ) -> Mapping[str, Optional[bool]]: + """Return mapping from user ID to whether they share a room with the + given user via being either joined or invited. + + Note: `None` and `False` are equivalent and mean they don't share a + room. + """ + + def do_users_share_a_room_joined_or_invited_txn( + txn: LoggingTransaction, user_ids: Collection[str] + ) -> Dict[str, bool]: + clause, args = make_in_list_sql_clause( + self.database_engine, "state_key", user_ids + ) + + # This query works by fetching both the list of rooms for the target + # user and the set of other users, and then checking if there is any + # overlap. + sql = f""" + SELECT DISTINCT b.state_key + FROM ( + SELECT room_id FROM current_state_events + WHERE type = 'm.room.member' AND (membership = 'join' OR membership = 'invite') AND state_key = ? 
+ ) AS a + INNER JOIN ( + SELECT room_id, state_key FROM current_state_events + WHERE type = 'm.room.member' AND (membership = 'join' OR membership = 'invite') AND {clause} + ) AS b using (room_id) + """ + + txn.execute(sql, (user_id, *args)) + return {u: True for (u,) in txn} + + to_return = {} + for batch_user_ids in batch_iter(other_user_ids, 1000): + res = await self.db_pool.runInteraction( + "do_users_share_a_room_joined_or_invited", + do_users_share_a_room_joined_or_invited_txn, + batch_user_ids, + ) + to_return.update(res) + + return to_return + + async def do_users_share_a_room_joined_or_invited( + self, user_id: str, other_user_ids: Collection[str] + ) -> Set[str]: + """Return the set of users who share a room with the first users via being either joined or invited""" + + user_dict = await self._do_users_share_a_room_joined_or_invited( + user_id, other_user_ids + ) + + return {u for u, share_room in user_dict.items() if share_room} + async def get_users_who_share_room_with_user(self, user_id: str) -> Set[str]: """Returns the set of users who share a room with `user_id`""" room_ids = await self.get_rooms_for_user(user_id) @@ -1026,7 +1136,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): AND room_id = ? """ txn.execute(sql, (room_id,)) - return {d for d, in txn} + return {d for (d,) in txn} return await self.db_pool.runInteraction( "get_current_hosts_in_room", get_current_hosts_in_room_txn @@ -1094,7 +1204,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): """ txn.execute(sql, (room_id,)) # `server_domain` will be `NULL` for malformed MXIDs with no colons. - return tuple(d for d, in txn if d is not None) + return tuple(d for (d,) in txn if d is not None) return await self.db_pool.runInteraction( "get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn @@ -1311,9 +1421,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): room_id = ? 
AND membership = ? AND NOT (%s) LIMIT 1 - """ % ( - clause, - ) + """ % (clause,) def _is_local_host_in_room_ignoring_users_txn( txn: LoggingTransaction, @@ -1337,11 +1445,23 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): keyvalues={"user_id": user_id, "room_id": room_id}, updatevalues={"forgotten": 1}, ) + # Handle updating the `sliding_sync_membership_snapshots` table + self.db_pool.simple_update_txn( + txn, + table="sliding_sync_membership_snapshots", + keyvalues={"user_id": user_id, "room_id": room_id}, + updatevalues={"forgotten": 1}, + ) self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id)) self._invalidate_cache_and_stream( txn, self.get_forgotten_rooms_for_user, (user_id,) ) + self._invalidate_cache_and_stream( + txn, + self.get_sliding_sync_rooms_for_user_from_membership_snapshots, + (user_id,), + ) await self.db_pool.runInteraction("forget_membership", f) @@ -1371,6 +1491,360 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): desc="room_forgetter_stream_pos", ) + @cached(iterable=True, max_entries=10000) + async def get_sliding_sync_rooms_for_user_from_membership_snapshots( + self, user_id: str + ) -> Mapping[str, RoomsForUserSlidingSync]: + """ + Get all the rooms for a user to handle a sliding sync request from the + `sliding_sync_membership_snapshots` table. These will be current memberships and + need to be rewound to the token range. + + Ignores forgotten rooms and rooms that the user has left themselves. + + Args: + user_id: The user ID to get the rooms for. + + Returns: + Map from room ID to membership info + """ + + def _txn( + txn: LoggingTransaction, + ) -> Dict[str, RoomsForUserSlidingSync]: + # XXX: If you use any new columns that can change (like from + # `sliding_sync_joined_rooms` or `forgotten`), make sure to bust the + # `get_sliding_sync_rooms_for_user_from_membership_snapshots` cache in the + # appropriate places (and add tests). 
+ sql = """ + SELECT m.room_id, m.sender, m.membership, m.membership_event_id, + r.room_version, + m.event_instance_name, m.event_stream_ordering, + m.has_known_state, + COALESCE(j.room_type, m.room_type), + COALESCE(j.is_encrypted, m.is_encrypted) + FROM sliding_sync_membership_snapshots AS m + INNER JOIN rooms AS r USING (room_id) + LEFT JOIN sliding_sync_joined_rooms AS j ON (j.room_id = m.room_id AND m.membership = 'join') + WHERE user_id = ? + AND m.forgotten = 0 + AND (m.membership != 'leave' OR m.user_id != m.sender) + """ + txn.execute(sql, (user_id,)) + + return { + row[0]: RoomsForUserSlidingSync( + room_id=row[0], + sender=row[1], + membership=row[2], + event_id=row[3], + room_version_id=row[4], + event_pos=PersistedEventPosition(row[5], row[6]), + has_known_state=bool(row[7]), + room_type=row[8], + is_encrypted=bool(row[9]), + ) + for row in txn + # We filter out unknown room versions proactively. They + # shouldn't go down sync and their metadata may be in a broken + # state (causing errors). + if row[4] in KNOWN_ROOM_VERSIONS + } + + return await self.db_pool.runInteraction( + "get_sliding_sync_rooms_for_user_from_membership_snapshots", + _txn, + ) + + async def get_sliding_sync_self_leave_rooms_after_to_token( + self, + user_id: str, + to_token: StreamToken, + ) -> Dict[str, RoomsForUserSlidingSync]: + """ + Get all the self-leave rooms for a user after the `to_token` (outside the token + range) that are potentially relevant[1] and needed to handle a sliding sync + request. The results are from the `sliding_sync_membership_snapshots` table and + will be current memberships and need to be rewound to the token range. + + [1] If a leave happens after the token range, we may have still been joined (or + any non-self-leave which is relevant to sync) to the room before so we need to + include it in the list of potentially relevant rooms and apply + our rewind logic (outside of this function) to see if it's actually relevant. 
+ + This is basically a sister-function to + `get_sliding_sync_rooms_for_user_from_membership_snapshots`. We could + alternatively incorporate this logic into + `get_sliding_sync_rooms_for_user_from_membership_snapshots` but those results + are cached and the `to_token` isn't very cache friendly (people are constantly + requesting with new tokens) so we separate it out here. + + Args: + user_id: The user ID to get the rooms for. + to_token: Any self-leave memberships after this position will be returned. + + Returns: + Map from room ID to membership info + """ + # TODO: Potential to check + # `self._membership_stream_cache.has_entity_changed(...)` as an early-return + # shortcut. + + def _txn( + txn: LoggingTransaction, + ) -> Dict[str, RoomsForUserSlidingSync]: + sql = """ + SELECT m.room_id, m.sender, m.membership, m.membership_event_id, + r.room_version, + m.event_instance_name, m.event_stream_ordering, + m.has_known_state, + m.room_type, + m.is_encrypted + FROM sliding_sync_membership_snapshots AS m + INNER JOIN rooms AS r USING (room_id) + WHERE user_id = ? + AND m.forgotten = 0 + AND m.membership = 'leave' + AND m.user_id = m.sender + AND (m.event_stream_ordering > ?) + """ + # If a leave happens after the token range, we may have still been joined + # (or any non-self-leave which is relevant to sync) to the room before so we + # need to include it in the list of potentially relevant rooms and apply our + # rewind logic (outside of this function). 
+ # + # To handle tokens with a non-empty instance_map we fetch more + # results than necessary and then filter down + min_to_token_position = to_token.room_key.stream + txn.execute(sql, (user_id, min_to_token_position)) + + # Map from room_id to membership info + room_membership_for_user_map: Dict[str, RoomsForUserSlidingSync] = {} + for row in txn: + room_for_user = RoomsForUserSlidingSync( + room_id=row[0], + sender=row[1], + membership=row[2], + event_id=row[3], + room_version_id=row[4], + event_pos=PersistedEventPosition(row[5], row[6]), + has_known_state=bool(row[7]), + room_type=row[8], + is_encrypted=bool(row[9]), + ) + + # We filter out unknown room versions proactively. They shouldn't go + # down sync and their metadata may be in a broken state (causing + # errors). + if row[4] not in KNOWN_ROOM_VERSIONS: + continue + + # We only want to include the self-leave membership if it happened after + # the token range. + # + # Since the database pulls out more than necessary, we need to filter it + # down here. 
+ if _filter_results_by_stream( + lower_token=None, + upper_token=to_token.room_key, + instance_name=room_for_user.event_pos.instance_name, + stream_ordering=room_for_user.event_pos.stream, + ): + continue + + room_membership_for_user_map[room_for_user.room_id] = room_for_user + + return room_membership_for_user_map + + return await self.db_pool.runInteraction( + "get_sliding_sync_self_leave_rooms_after_to_token", + _txn, + ) + + async def get_sliding_sync_room_for_user( + self, user_id: str, room_id: str + ) -> Optional[RoomsForUserSlidingSync]: + """Get the sliding sync room entry for the given user and room.""" + + def get_sliding_sync_room_for_user_txn( + txn: LoggingTransaction, + ) -> Optional[RoomsForUserSlidingSync]: + sql = """ + SELECT m.room_id, m.sender, m.membership, m.membership_event_id, + r.room_version, + m.event_instance_name, m.event_stream_ordering, + m.has_known_state, + COALESCE(j.room_type, m.room_type), + COALESCE(j.is_encrypted, m.is_encrypted) + FROM sliding_sync_membership_snapshots AS m + INNER JOIN rooms AS r USING (room_id) + LEFT JOIN sliding_sync_joined_rooms AS j ON (j.room_id = m.room_id AND m.membership = 'join') + WHERE user_id = ? + AND m.forgotten = 0 + AND m.room_id = ? 
+ """ + txn.execute(sql, (user_id, room_id)) + row = txn.fetchone() + if not row: + return None + + return RoomsForUserSlidingSync( + room_id=row[0], + sender=row[1], + membership=row[2], + event_id=row[3], + room_version_id=row[4], + event_pos=PersistedEventPosition(row[5], row[6]), + has_known_state=bool(row[7]), + room_type=row[8], + is_encrypted=row[9], + ) + + return await self.db_pool.runInteraction( + "get_sliding_sync_room_for_user", get_sliding_sync_room_for_user_txn + ) + + async def get_sliding_sync_room_for_user_batch( + self, user_id: str, room_ids: StrCollection + ) -> Dict[str, RoomsForUserSlidingSync]: + """Get the sliding sync room entry for the given user and rooms.""" + + if not room_ids: + return {} + + def get_sliding_sync_room_for_user_batch_txn( + txn: LoggingTransaction, + ) -> Dict[str, RoomsForUserSlidingSync]: + clause, args = make_in_list_sql_clause( + self.database_engine, "m.room_id", room_ids + ) + sql = f""" + SELECT m.room_id, m.sender, m.membership, m.membership_event_id, + r.room_version, + m.event_instance_name, m.event_stream_ordering, + m.has_known_state, + COALESCE(j.room_type, m.room_type), + COALESCE(j.is_encrypted, m.is_encrypted) + FROM sliding_sync_membership_snapshots AS m + INNER JOIN rooms AS r USING (room_id) + LEFT JOIN sliding_sync_joined_rooms AS j ON (j.room_id = m.room_id AND m.membership = 'join') + WHERE m.forgotten = 0 + AND {clause} + AND user_id = ? 
+ """ + args.append(user_id) + txn.execute(sql, args) + + return { + row[0]: RoomsForUserSlidingSync( + room_id=row[0], + sender=row[1], + membership=row[2], + event_id=row[3], + room_version_id=row[4], + event_pos=PersistedEventPosition(row[5], row[6]), + has_known_state=bool(row[7]), + room_type=row[8], + is_encrypted=row[9], + ) + for row in txn + } + + return await self.db_pool.runInteraction( + "get_sliding_sync_room_for_user_batch", + get_sliding_sync_room_for_user_batch_txn, + ) + + async def get_rooms_for_user_by_date( + self, user_id: str, from_ts: int + ) -> FrozenSet[str]: + """ + Fetch a list of rooms that the user has joined at or after the given timestamp, including + those they subsequently have left/been banned from. + + Args: + user_id: user ID of the user to search for + from_ts: a timestamp in ms from the unix epoch at which to begin the search at + """ + + def _get_rooms_for_user_by_join_date_txn( + txn: LoggingTransaction, user_id: str, timestamp: int + ) -> frozenset: + sql = """ + SELECT rm.room_id + FROM room_memberships AS rm + INNER JOIN events AS e USING (event_id) + WHERE rm.user_id = ? + AND rm.membership = 'join' + AND e.type = 'm.room.member' + AND e.received_ts >= ? + """ + txn.execute(sql, (user_id, timestamp)) + return frozenset([r[0] for r in txn]) + + return await self.db_pool.runInteraction( + "_get_rooms_for_user_by_join_date_txn", + _get_rooms_for_user_by_join_date_txn, + user_id, + from_ts, + ) + + async def set_room_participation(self, user_id: str, room_id: str) -> None: + """ + Record the provided user as participating in the given room + + Args: + user_id: the user ID of the user + room_id: ID of the room to set the participant in + """ + + def _set_room_participation_txn( + txn: LoggingTransaction, user_id: str, room_id: str + ) -> None: + sql = """ + UPDATE room_memberships + SET participant = true + WHERE event_id IN ( + SELECT event_id FROM local_current_membership + WHERE user_id = ? AND room_id = ? 
+ ) + AND NOT participant + """ + txn.execute(sql, (user_id, room_id)) + + await self.db_pool.runInteraction( + "_set_room_participation_txn", _set_room_participation_txn, user_id, room_id + ) + + async def get_room_participation(self, user_id: str, room_id: str) -> bool: + """ + Check whether a user is listed as a participant in a room + + Args: + user_id: user ID of the user + room_id: ID of the room to check in + """ + + def _get_room_participation_txn( + txn: LoggingTransaction, user_id: str, room_id: str + ) -> bool: + sql = """ + SELECT participant + FROM local_current_membership AS l + INNER JOIN room_memberships AS r USING (event_id) + WHERE l.user_id = ? + AND l.room_id = ? + """ + txn.execute(sql, (user_id, room_id)) + res = txn.fetchone() + if res: + return res[0] + return False + + return await self.db_pool.runInteraction( + "_get_room_participation_txn", _get_room_participation_txn, user_id, room_id + ) + class RoomMemberBackgroundUpdateStore(SQLBaseStore): def __init__( @@ -1405,10 +1879,12 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): self, progress: JsonDict, batch_size: int ) -> int: target_min_stream_id = progress.get( - "target_min_stream_id_inclusive", self._min_stream_order_on_start # type: ignore[attr-defined] + "target_min_stream_id_inclusive", + self._min_stream_order_on_start, # type: ignore[attr-defined] ) max_stream_id = progress.get( - "max_stream_id_exclusive", self._stream_order_on_start + 1 # type: ignore[attr-defined] + "max_stream_id_exclusive", + self._stream_order_on_start + 1, # type: ignore[attr-defined] ) def add_membership_profile_txn(txn: LoggingTransaction) -> int: diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 20fcfd3122..1d5c5e72ff 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py
@@ -94,7 +94,7 @@ class SearchWorkerStore(SQLBaseStore): VALUES (?,?,?,to_tsvector('english', ?),?,?) """ - args1 = ( + args1 = [ ( entry.event_id, entry.room_id, @@ -104,7 +104,7 @@ class SearchWorkerStore(SQLBaseStore): entry.origin_server_ts, ) for entry in entries - ) + ] txn.execute_batch(sql, args1) @@ -177,9 +177,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): AND (%s) ORDER BY stream_ordering DESC LIMIT ? - """ % ( - " OR ".join("type = '%s'" % (t,) for t in TYPES), - ) + """ % (" OR ".join("type = '%s'" % (t,) for t in TYPES),) txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) diff --git a/synapse/storage/databases/main/sliding_sync.py b/synapse/storage/databases/main/sliding_sync.py new file mode 100644
index 0000000000..6a62b11d1e --- /dev/null +++ b/synapse/storage/databases/main/sliding_sync.py
@@ -0,0 +1,603 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2023, 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# + + +import logging +from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Set, cast + +import attr + +from synapse.api.errors import SlidingSyncUnknownPosition +from synapse.logging.opentracing import log_kv +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) +from synapse.types import MultiWriterStreamToken, RoomStreamToken +from synapse.types.handlers.sliding_sync import ( + HaveSentRoom, + HaveSentRoomFlag, + MutablePerConnectionState, + PerConnectionState, + RoomStatusMap, + RoomSyncConfig, +) +from synapse.util import json_encoder +from synapse.util.caches.descriptors import cached + +if TYPE_CHECKING: + from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore + +logger = logging.getLogger(__name__) + + +class SlidingSyncStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_index_update( + update_name="sliding_sync_connection_room_configs_required_state_id_idx", + index_name="sliding_sync_connection_room_configs_required_state_id_idx", + table="sliding_sync_connection_room_configs", + columns=("required_state_id",), + ) + + self.db_pool.updates.register_background_index_update( + 
update_name="sliding_sync_membership_snapshots_membership_event_id_idx", + index_name="sliding_sync_membership_snapshots_membership_event_id_idx", + table="sliding_sync_membership_snapshots", + columns=("membership_event_id",), + ) + + self.db_pool.updates.register_background_index_update( + update_name="sliding_sync_membership_snapshots_user_id_stream_ordering", + index_name="sliding_sync_membership_snapshots_user_id_stream_ordering", + table="sliding_sync_membership_snapshots", + columns=("user_id", "event_stream_ordering"), + replaces_index="sliding_sync_membership_snapshots_user_id", + ) + + async def get_latest_bump_stamp_for_room( + self, + room_id: str, + ) -> Optional[int]: + """ + Get the `bump_stamp` for the room. + + The `bump_stamp` is the `stream_ordering` of the last event according to the + `bump_event_types`. This helps clients sort more readily without them needing to + pull in a bunch of the timeline to determine the last activity. + `bump_event_types` is a thing because for example, we don't want display name + changes to mark the room as unread and bump it to the top. For encrypted rooms, + we just have to consider any activity as a bump because we can't see the content + and the client has to figure it out for themselves. + + This should only be called where the server is participating + in the room (someone local is joined). + + Returns: + The `bump_stamp` for the room (which can be `None`). 
+ """ + + return cast( + Optional[int], + await self.db_pool.simple_select_one_onecol( + table="sliding_sync_joined_rooms", + keyvalues={"room_id": room_id}, + retcol="bump_stamp", + # FIXME: This should be `False` once we bump `SCHEMA_COMPAT_VERSION` and run the + # foreground update for + # `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked + # by https://github.com/element-hq/synapse/issues/17623) + # + # The should be `allow_none=False` in the future because event though + # `bump_stamp` itself can be `None`, we should have a row in the + # `sliding_sync_joined_rooms` table for any joined room. + allow_none=True, + ), + ) + + async def persist_per_connection_state( + self, + user_id: str, + device_id: str, + conn_id: str, + previous_connection_position: Optional[int], + per_connection_state: "MutablePerConnectionState", + ) -> int: + """Persist updates to the per-connection state for a sliding sync + connection. + + Returns: + The connection position of the newly persisted state. + """ + + # This cast is safe because the downstream code only cares about + # `store.get_id_for_instance(...)` and `StreamWorkerStore` is mixed + # alongside `SlidingSyncStore` wherever we create a store. + store = cast("DataStore", self) + + return await self.db_pool.runInteraction( + "persist_per_connection_state", + self.persist_per_connection_state_txn, + user_id=user_id, + device_id=device_id, + conn_id=conn_id, + previous_connection_position=previous_connection_position, + per_connection_state=await PerConnectionStateDB.from_state( + per_connection_state, store + ), + ) + + def persist_per_connection_state_txn( + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + conn_id: str, + previous_connection_position: Optional[int], + per_connection_state: "PerConnectionStateDB", + ) -> int: + # First we fetch (or create) the connection key associated with the + # previous connection position. 
+ if previous_connection_position is not None: + # The `previous_connection_position` is a user-supplied value, so we + # need to make sure that the one they supplied is actually theirs. + sql = """ + SELECT connection_key + FROM sliding_sync_connection_positions + INNER JOIN sliding_sync_connections USING (connection_key) + WHERE + connection_position = ? + AND user_id = ? AND effective_device_id = ? AND conn_id = ? + """ + txn.execute( + sql, (previous_connection_position, user_id, device_id, conn_id) + ) + row = txn.fetchone() + if row is None: + raise SlidingSyncUnknownPosition() + + (connection_key,) = row + else: + # We're restarting the connection, so we clear the previous existing data we + # used to track it. We do this here to ensure that if we get lots of + # one-shot requests we don't stack up lots of entries. We have `ON DELETE + # CASCADE` setup on the dependent tables so this will clear out all the + # associated data. + self.db_pool.simple_delete_txn( + txn, + table="sliding_sync_connections", + keyvalues={ + "user_id": user_id, + "effective_device_id": device_id, + "conn_id": conn_id, + }, + ) + + (connection_key,) = self.db_pool.simple_insert_returning_txn( + txn, + table="sliding_sync_connections", + values={ + "user_id": user_id, + "effective_device_id": device_id, + "conn_id": conn_id, + "created_ts": self._clock.time_msec(), + }, + returning=("connection_key",), + ) + + # Define a new connection position for the updates + (connection_position,) = self.db_pool.simple_insert_returning_txn( + txn, + table="sliding_sync_connection_positions", + values={ + "connection_key": connection_key, + "created_ts": self._clock.time_msec(), + }, + returning=("connection_position",), + ) + + # We need to deduplicate the `required_state` JSON. 
We do this by + # fetching all JSON associated with the connection and comparing that + # with the updates to `required_state` + + # Dict from required state json -> required state ID + required_state_to_id: Dict[str, int] = {} + if previous_connection_position is not None: + rows = self.db_pool.simple_select_list_txn( + txn, + table="sliding_sync_connection_required_state", + keyvalues={"connection_key": connection_key}, + retcols=("required_state_id", "required_state"), + ) + for required_state_id, required_state in rows: + required_state_to_id[required_state] = required_state_id + + room_to_state_ids: Dict[str, int] = {} + unique_required_state: Dict[str, List[str]] = {} + for room_id, room_state in per_connection_state.room_configs.items(): + serialized_state = json_encoder.encode( + # We store the required state as a sorted list of event type / + # state key tuples. + sorted( + (event_type, state_key) + for event_type, state_keys in room_state.required_state_map.items() + for state_key in state_keys + ) + ) + + existing_state_id = required_state_to_id.get(serialized_state) + if existing_state_id is not None: + room_to_state_ids[room_id] = existing_state_id + else: + unique_required_state.setdefault(serialized_state, []).append(room_id) + + # Insert any new `required_state` json we haven't previously seen. + for serialized_required_state, room_ids in unique_required_state.items(): + (required_state_id,) = self.db_pool.simple_insert_returning_txn( + txn, + table="sliding_sync_connection_required_state", + values={ + "connection_key": connection_key, + "required_state": serialized_required_state, + }, + returning=("required_state_id",), + ) + for room_id in room_ids: + room_to_state_ids[room_id] = required_state_id + + # Copy over state from the previous connection position (we'll overwrite + # these rows with any changes). 
+ if previous_connection_position is not None: + sql = """ + INSERT INTO sliding_sync_connection_streams + (connection_position, stream, room_id, room_status, last_token) + SELECT ?, stream, room_id, room_status, last_token + FROM sliding_sync_connection_streams + WHERE connection_position = ? + """ + txn.execute(sql, (connection_position, previous_connection_position)) + + sql = """ + INSERT INTO sliding_sync_connection_room_configs + (connection_position, room_id, timeline_limit, required_state_id) + SELECT ?, room_id, timeline_limit, required_state_id + FROM sliding_sync_connection_room_configs + WHERE connection_position = ? + """ + txn.execute(sql, (connection_position, previous_connection_position)) + + # We now upsert the changes to the various streams. + key_values = [] + value_values = [] + for room_id, have_sent_room in per_connection_state.rooms._statuses.items(): + key_values.append((connection_position, "rooms", room_id)) + value_values.append( + (have_sent_room.status.value, have_sent_room.last_token) + ) + + for room_id, have_sent_room in per_connection_state.receipts._statuses.items(): + key_values.append((connection_position, "receipts", room_id)) + value_values.append( + (have_sent_room.status.value, have_sent_room.last_token) + ) + + for ( + room_id, + have_sent_room, + ) in per_connection_state.account_data._statuses.items(): + key_values.append((connection_position, "account_data", room_id)) + value_values.append( + (have_sent_room.status.value, have_sent_room.last_token) + ) + + self.db_pool.simple_upsert_many_txn( + txn, + table="sliding_sync_connection_streams", + key_names=( + "connection_position", + "stream", + "room_id", + ), + key_values=key_values, + value_names=( + "room_status", + "last_token", + ), + value_values=value_values, + ) + + # ... and upsert changes to the room configs. 
+ keys = [] + values = [] + for room_id, room_config in per_connection_state.room_configs.items(): + keys.append((connection_position, room_id)) + values.append((room_config.timeline_limit, room_to_state_ids[room_id])) + + self.db_pool.simple_upsert_many_txn( + txn, + table="sliding_sync_connection_room_configs", + key_names=( + "connection_position", + "room_id", + ), + key_values=keys, + value_names=( + "timeline_limit", + "required_state_id", + ), + value_values=values, + ) + + return connection_position + + @cached(iterable=True, max_entries=100000) + async def get_and_clear_connection_positions( + self, user_id: str, device_id: str, conn_id: str, connection_position: int + ) -> "PerConnectionState": + """Get the per-connection state for the given connection position.""" + + per_connection_state_db = await self.db_pool.runInteraction( + "get_and_clear_connection_positions", + self._get_and_clear_connection_positions_txn, + user_id=user_id, + device_id=device_id, + conn_id=conn_id, + connection_position=connection_position, + ) + + # This cast is safe because the downstream code only cares about + # `store.get_id_for_instance(...)` and `StreamWorkerStore` is mixed + # alongside `SlidingSyncStore` wherever we create a store. + store = cast("DataStore", self) + + return await per_connection_state_db.to_state(store) + + def _get_and_clear_connection_positions_txn( + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + conn_id: str, + connection_position: int, + ) -> "PerConnectionStateDB": + # The `previous_connection_position` is a user-supplied value, so we + # need to make sure that the one they supplied is actually theirs. + sql = """ + SELECT connection_key + FROM sliding_sync_connection_positions + INNER JOIN sliding_sync_connections USING (connection_key) + WHERE + connection_position = ? + AND user_id = ? AND effective_device_id = ? AND conn_id = ? 
+ """ + txn.execute(sql, (connection_position, user_id, device_id, conn_id)) + row = txn.fetchone() + if row is None: + raise SlidingSyncUnknownPosition() + + (connection_key,) = row + + # Now that we have seen the client has received and used the connection + # position, we can delete all the other connection positions. + sql = """ + DELETE FROM sliding_sync_connection_positions + WHERE connection_key = ? AND connection_position != ? + """ + txn.execute(sql, (connection_key, connection_position)) + + # Fetch and create a mapping from required state ID to the actual + # required state for the connection. + rows = self.db_pool.simple_select_list_txn( + txn, + table="sliding_sync_connection_required_state", + keyvalues={"connection_key": connection_key}, + retcols=( + "required_state_id", + "required_state", + ), + ) + + required_state_map: Dict[int, Dict[str, Set[str]]] = {} + for row in rows: + state = required_state_map[row[0]] = {} + for event_type, state_key in db_to_json(row[1]): + state.setdefault(event_type, set()).add(state_key) + + # Get all the room configs, looking up the required state from the map + # above. + room_config_rows = self.db_pool.simple_select_list_txn( + txn, + table="sliding_sync_connection_room_configs", + keyvalues={"connection_position": connection_position}, + retcols=( + "room_id", + "timeline_limit", + "required_state_id", + ), + ) + + room_configs: Dict[str, RoomSyncConfig] = {} + for ( + room_id, + timeline_limit, + required_state_id, + ) in room_config_rows: + room_configs[room_id] = RoomSyncConfig( + timeline_limit=timeline_limit, + required_state_map=required_state_map[required_state_id], + ) + + # Now look up the per-room stream data. 
+ rooms: Dict[str, HaveSentRoom[str]] = {} + receipts: Dict[str, HaveSentRoom[str]] = {} + account_data: Dict[str, HaveSentRoom[str]] = {} + + receipt_rows = self.db_pool.simple_select_list_txn( + txn, + table="sliding_sync_connection_streams", + keyvalues={"connection_position": connection_position}, + retcols=( + "stream", + "room_id", + "room_status", + "last_token", + ), + ) + for stream, room_id, room_status, last_token in receipt_rows: + have_sent_room: HaveSentRoom[str] = HaveSentRoom( + status=HaveSentRoomFlag(room_status), last_token=last_token + ) + if stream == "rooms": + rooms[room_id] = have_sent_room + elif stream == "receipts": + receipts[room_id] = have_sent_room + elif stream == "account_data": + account_data[room_id] = have_sent_room + else: + # For forwards compatibility we ignore unknown streams, as in + # future we want to be able to easily add more stream types. + logger.warning("Unrecognized sliding sync stream in DB %r", stream) + + return PerConnectionStateDB( + rooms=RoomStatusMap(rooms), + receipts=RoomStatusMap(receipts), + account_data=RoomStatusMap(account_data), + room_configs=room_configs, + ) + + +@attr.s(auto_attribs=True, frozen=True) +class PerConnectionStateDB: + """An equivalent to `PerConnectionState` that holds data in a format stored + in the DB. + + The principle difference is that the tokens for the different streams are + serialized to strings. + + When persisting this *only* contains updates to the state. 
+ """ + + rooms: "RoomStatusMap[str]" + receipts: "RoomStatusMap[str]" + account_data: "RoomStatusMap[str]" + + room_configs: Mapping[str, "RoomSyncConfig"] + + @staticmethod + async def from_state( + per_connection_state: "MutablePerConnectionState", store: "DataStore" + ) -> "PerConnectionStateDB": + """Convert from a standard `PerConnectionState`""" + rooms = { + room_id: HaveSentRoom( + status=status.status, + last_token=( + await status.last_token.to_string(store) + if status.last_token is not None + else None + ), + ) + for room_id, status in per_connection_state.rooms.get_updates().items() + } + + receipts = { + room_id: HaveSentRoom( + status=status.status, + last_token=( + await status.last_token.to_string(store) + if status.last_token is not None + else None + ), + ) + for room_id, status in per_connection_state.receipts.get_updates().items() + } + + account_data = { + room_id: HaveSentRoom( + status=status.status, + last_token=( + str(status.last_token) if status.last_token is not None else None + ), + ) + for room_id, status in per_connection_state.account_data.get_updates().items() + } + + log_kv( + { + "rooms": rooms, + "receipts": receipts, + "account_data": account_data, + "room_configs": per_connection_state.room_configs.maps[0], + } + ) + + return PerConnectionStateDB( + rooms=RoomStatusMap(rooms), + receipts=RoomStatusMap(receipts), + account_data=RoomStatusMap(account_data), + room_configs=per_connection_state.room_configs.maps[0], + ) + + async def to_state(self, store: "DataStore") -> "PerConnectionState": + """Convert into a standard `PerConnectionState`""" + rooms = { + room_id: HaveSentRoom( + status=status.status, + last_token=( + await RoomStreamToken.parse(store, status.last_token) + if status.last_token is not None + else None + ), + ) + for room_id, status in self.rooms._statuses.items() + } + + receipts = { + room_id: HaveSentRoom( + status=status.status, + last_token=( + await MultiWriterStreamToken.parse(store, status.last_token) + 
if status.last_token is not None + else None + ), + ) + for room_id, status in self.receipts._statuses.items() + } + + account_data = { + room_id: HaveSentRoom( + status=status.status, + last_token=( + int(status.last_token) if status.last_token is not None else None + ), + ) + for room_id, status in self.account_data._statuses.items() + } + + return PerConnectionState( + rooms=RoomStatusMap(rooms), + receipts=RoomStatusMap(receipts), + account_data=RoomStatusMap(account_data), + room_configs=self.room_configs, + ) diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 62bc4600fb..788f7d1e32 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py
@@ -308,8 +308,24 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return create_event @cached(max_entries=10000) - async def get_room_type(self, room_id: str) -> Optional[str]: - raise NotImplementedError() + async def get_room_type(self, room_id: str) -> Union[Optional[str], Sentinel]: + """Fetch room type for given room. + + Since this function is cached, any missing values would be cached as + `None`. In order to distinguish between a room whose room type is + `None` (no type set) and a room that is unknown to the server where we + might want to omit the value (which would make it cached as `None`), + instead we use the sentinel value `ROOM_UNKNOWN_SENTINEL`. + """ + + try: + create_event = await self.get_create_event_for_room(room_id) + return create_event.content.get(EventContentFields.ROOM_TYPE) + except NotFoundError: + # We use the sentinel value to distinguish between `None` which is a + # valid room type and a room that is unknown to the server so the value + # is just unset. + return ROOM_UNKNOWN_SENTINEL @cachedList(cached_method_name="get_room_type", list_name="room_ids") async def bulk_get_room_type( @@ -535,7 +551,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): desc="check_if_events_in_current_state", ) - return frozenset(event_id for event_id, in rows) + return frozenset(event_id for (event_id,) in rows) # FIXME: how should this be cached? @cancellable @@ -556,10 +572,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: Map from type/state_key to event ID. 
""" + if state_filter is None: + state_filter = StateFilter.all() - where_clause, where_args = ( - state_filter or StateFilter.all() - ).make_sql_filter_clause() + where_clause, where_args = (state_filter).make_sql_filter_clause() if not where_clause: # We delegate to the cached version @@ -568,7 +584,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): def _get_filtered_current_state_ids_txn( txn: LoggingTransaction, ) -> StateMap[str]: - results = StateMapWrapper(state_filter=state_filter or StateFilter.all()) + results = StateMapWrapper(state_filter=state_filter) sql = """ SELECT type, state_key, event_id FROM current_state_events @@ -665,7 +681,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): context: EventContext, ) -> None: """Update the state group for a partial state event""" - async with self._un_partial_stated_events_stream_id_gen.get_next() as un_partial_state_event_stream_id: + async with ( + self._un_partial_stated_events_stream_id_gen.get_next() as un_partial_state_event_stream_id + ): await self.db_pool.runInteraction( "update_state_for_partial_state_event", self._update_state_for_partial_state_event_txn, @@ -736,6 +754,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events" + MEMBERS_CURRENT_STATE_UPDATE_NAME = "current_state_events_members_room_index" def __init__( self, @@ -764,6 +783,13 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): self.DELETE_CURRENT_STATE_UPDATE_NAME, self._background_remove_left_rooms, ) + self.db_pool.updates.register_background_index_update( + self.MEMBERS_CURRENT_STATE_UPDATE_NAME, + index_name="current_state_events_members_room_index", + table="current_state_events", + columns=["room_id", "membership"], + where_clause="type='m.room.member'", + ) async def 
_background_remove_left_rooms( self, progress: JsonDict, batch_size: int diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 9ed39e688a..00f87cc3a1 100644 --- a/synapse/storage/databases/main/state_deltas.py +++ b/synapse/storage/databases/main/state_deltas.py
@@ -20,16 +20,25 @@ # import logging -from typing import List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple import attr from synapse.logging.opentracing import trace from synapse.storage._base import SQLBaseStore -from synapse.storage.database import LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + make_in_list_sql_clause, +) from synapse.storage.databases.main.stream import _filter_results_by_stream -from synapse.types import RoomStreamToken +from synapse.types import RoomStreamToken, StrCollection from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.util.iterutils import batch_iter + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -53,6 +62,21 @@ class StateDeltasStore(SQLBaseStore): # attribute. TODO: can we get static analysis to enforce this? _curr_state_delta_stream_cache: StreamChangeCache + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_index_update( + update_name="current_state_delta_stream_room_index", + index_name="current_state_delta_stream_room_idx", + table="current_state_delta_stream", + columns=("room_id", "stream_id"), + ) + async def get_partial_current_state_deltas( self, prev_stream_id: int, max_stream_id: int ) -> Tuple[int, List[StateDelta]]: @@ -74,9 +98,9 @@ class StateDeltasStore(SQLBaseStore): prev_stream_id = int(prev_stream_id) # check we're not going backwards - assert ( - prev_stream_id <= max_stream_id - ), f"New stream id {max_stream_id} is smaller than prev stream id {prev_stream_id}" + assert prev_stream_id <= max_stream_id, ( + f"New stream id {max_stream_id} is smaller than prev stream id {prev_stream_id}" + ) if not self._curr_state_delta_stream_cache.has_any_entity_changed( prev_stream_id @@ 
-160,38 +184,144 @@ class StateDeltasStore(SQLBaseStore): self._get_max_stream_id_in_current_state_deltas_txn, ) + def get_current_state_deltas_for_room_txn( + self, + txn: LoggingTransaction, + room_id: str, + *, + from_token: Optional[RoomStreamToken], + to_token: Optional[RoomStreamToken], + ) -> List[StateDelta]: + """ + Get the state deltas between two tokens. + + (> `from_token` and <= `to_token`) + """ + from_clause = "" + from_args = [] + if from_token is not None: + from_clause = "AND ? < stream_id" + from_args = [from_token.stream] + + to_clause = "" + to_args = [] + if to_token is not None: + to_clause = "AND stream_id <= ?" + to_args = [to_token.get_max_stream_pos()] + + sql = f""" + SELECT instance_name, stream_id, type, state_key, event_id, prev_event_id + FROM current_state_delta_stream + WHERE room_id = ? {from_clause} {to_clause} + ORDER BY stream_id ASC + """ + txn.execute(sql, [room_id] + from_args + to_args) + + return [ + StateDelta( + stream_id=row[1], + room_id=room_id, + event_type=row[2], + state_key=row[3], + event_id=row[4], + prev_event_id=row[5], + ) + for row in txn + if _filter_results_by_stream(from_token, to_token, row[0], row[1]) + ] + @trace async def get_current_state_deltas_for_room( - self, room_id: str, from_token: RoomStreamToken, to_token: RoomStreamToken + self, + room_id: str, + *, + from_token: Optional[RoomStreamToken], + to_token: Optional[RoomStreamToken], ) -> List[StateDelta]: - """Get the state deltas between two tokens.""" + """ + Get the state deltas between two tokens. 
+ + (> `from_token` and <= `to_token`) + """ + # We can bail early if the `from_token` is after the `to_token` + if ( + to_token is not None + and from_token is not None + and to_token.is_before_or_eq(from_token) + ): + return [] - def get_current_state_deltas_for_room_txn( + if ( + from_token is not None + and not self._curr_state_delta_stream_cache.has_entity_changed( + room_id, from_token.stream + ) + ): + return [] + + return await self.db_pool.runInteraction( + "get_current_state_deltas_for_room", + self.get_current_state_deltas_for_room_txn, + room_id, + from_token=from_token, + to_token=to_token, + ) + + @trace + async def get_current_state_deltas_for_rooms( + self, + room_ids: StrCollection, + from_token: RoomStreamToken, + to_token: RoomStreamToken, + ) -> List[StateDelta]: + """Get the state deltas between two tokens for the set of rooms.""" + + room_ids = self._curr_state_delta_stream_cache.get_entities_changed( + room_ids, from_token.stream + ) + if not room_ids: + return [] + + def get_current_state_deltas_for_rooms_txn( txn: LoggingTransaction, + room_ids: StrCollection, ) -> List[StateDelta]: - sql = """ - SELECT instance_name, stream_id, type, state_key, event_id, prev_event_id + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) + + sql = f""" + SELECT instance_name, stream_id, room_id, type, state_key, event_id, prev_event_id FROM current_state_delta_stream - WHERE room_id = ? AND ? < stream_id AND stream_id <= ? + WHERE {clause} AND ? < stream_id AND stream_id <= ? 
ORDER BY stream_id ASC """ - txn.execute( - sql, (room_id, from_token.stream, to_token.get_max_stream_pos()) - ) + args.append(from_token.stream) + args.append(to_token.get_max_stream_pos()) + + txn.execute(sql, args) return [ StateDelta( stream_id=row[1], - room_id=room_id, - event_type=row[2], - state_key=row[3], - event_id=row[4], - prev_event_id=row[5], + room_id=row[2], + event_type=row[3], + state_key=row[4], + event_id=row[5], + prev_event_id=row[6], ) for row in txn if _filter_results_by_stream(from_token, to_token, row[0], row[1]) ] - return await self.db_pool.runInteraction( - "get_current_state_deltas_for_room", get_current_state_deltas_for_room_txn - ) + results = [] + for batch in batch_iter(room_ids, 1000): + deltas = await self.db_pool.runInteraction( + "get_current_state_deltas_for_rooms", + get_current_state_deltas_for_rooms_txn, + batch, + ) + + results.extend(deltas) + + return results diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index e9f6a918c7..79c49e7fd9 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py
@@ -161,7 +161,7 @@ class StatsStore(StateDeltasStore): LIMIT ? """ txn.execute(sql, (last_user_id, batch_size)) - return [r for r, in txn] + return [r for (r,) in txn] users_to_work_on = await self.db_pool.runInteraction( "_populate_stats_process_users", _get_next_batch @@ -207,7 +207,7 @@ class StatsStore(StateDeltasStore): LIMIT ? """ txn.execute(sql, (last_room_id, batch_size)) - return [r for r, in txn] + return [r for (r,) in txn] rooms_to_work_on = await self.db_pool.runInteraction( "populate_stats_rooms_get_batch", _get_next_batch @@ -751,9 +751,7 @@ class StatsStore(StateDeltasStore): LEFT JOIN profiles AS p ON lmr.user_id = p.full_user_id {} GROUP BY lmr.user_id, displayname - """.format( - where_clause - ) + """.format(where_clause) # SQLite does not support SELECT COUNT(*) OVER() sql = """ diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 4989c960a6..3fda49f31f 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py
@@ -21,7 +21,7 @@ # # -""" This module is responsible for getting events from the DB for pagination +"""This module is responsible for getting events from the DB for pagination and event streaming. The order it returns events in depend on whether we are streaming forwards or @@ -50,6 +50,8 @@ from typing import ( Dict, Iterable, List, + Literal, + Mapping, Optional, Protocol, Set, @@ -60,7 +62,7 @@ from typing import ( import attr from immutabledict import immutabledict -from typing_extensions import Literal, assert_never +from typing_extensions import assert_never from twisted.internet import defer @@ -78,9 +80,10 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine +from synapse.storage.roommember import RoomsForUserStateReset from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import PersistedEventPosition, RoomStreamToken, StrCollection -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter @@ -107,7 +110,7 @@ class PaginateFunction(Protocol): to_key: Optional[RoomStreamToken] = None, direction: Direction = Direction.BACKWARDS, limit: int = 0, - ) -> Tuple[List[EventBase], RoomStreamToken]: ... + ) -> Tuple[List[EventBase], RoomStreamToken, bool]: ... # Used as return values for pagination APIs @@ -451,6 +454,8 @@ def _filter_results_by_stream( stream_ordering falls between the two tokens (taking a None token to mean unbounded). + The token range is defined by > `lower_token` and <= `upper_token`. + Used to filter results from fetching events in the DB against the given tokens. 
This is necessary to handle the case where the tokens include position maps, which we handle by fetching more than necessary from the DB @@ -678,7 +683,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): to_key: Optional[RoomStreamToken] = None, direction: Direction = Direction.BACKWARDS, limit: int = 0, - ) -> Dict[str, Tuple[List[EventBase], RoomStreamToken]]: + ) -> Dict[str, Tuple[List[EventBase], RoomStreamToken, bool]]: """Get new room events in stream ordering since `from_key`. Args: @@ -694,6 +699,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): A map from room id to a tuple containing: - list of recent events in the room - stream ordering key for the start of the chunk of events returned. + - a boolean to indicate if there were more events but we hit the limit When Direction.FORWARDS: from_key < x <= to_key, (ascending order) When Direction.BACKWARDS: from_key >= x > to_key, (descending order) @@ -749,6 +755,48 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if self._events_stream_cache.has_entity_changed(room_id, from_id) } + async def get_rooms_that_have_updates_since_sliding_sync_table( + self, + room_ids: StrCollection, + from_key: RoomStreamToken, + ) -> StrCollection: + """Return the rooms that probably have had updates since the given + token (changes that are > `from_key`).""" + # If the stream change cache is valid for the stream token, we can just + # use the result of that. + if from_key.stream >= self._events_stream_cache.get_earliest_known_position(): + return self._events_stream_cache.get_entities_changed( + room_ids, from_key.stream + ) + + def get_rooms_that_have_updates_since_sliding_sync_table_txn( + txn: LoggingTransaction, + ) -> StrCollection: + sql = """ + SELECT room_id + FROM sliding_sync_joined_rooms + WHERE {clause} + AND event_stream_ordering > ? 
+ """ + + results: Set[str] = set() + for batch in batch_iter(room_ids, 1000): + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", batch + ) + + args.append(from_key.stream) + txn.execute(sql.format(clause=clause), args) + + results.update(row[0] for row in txn) + + return results + + return await self.db_pool.runInteraction( + "get_rooms_that_have_updates_since_sliding_sync_table", + get_rooms_that_have_updates_since_sliding_sync_table_txn, + ) + async def paginate_room_events_by_stream_ordering( self, *, @@ -757,7 +805,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): to_key: Optional[RoomStreamToken] = None, direction: Direction = Direction.BACKWARDS, limit: int = 0, - ) -> Tuple[List[EventBase], RoomStreamToken]: + ) -> Tuple[List[EventBase], RoomStreamToken, bool]: """ Paginate events by `stream_ordering` in the room from the `from_key` in the given `direction` to the `to_key` or `limit`. @@ -772,8 +820,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): limit: Maximum number of events to return Returns: - The results as a list of events and a token that points to the end - of the result set. If no events are returned then the end of the + The results as a list of events, a token that points to the end of + the result set, and a boolean to indicate if there were more events + but we hit the limit. If no events are returned then the end of the stream has been reached (i.e. there are no events between `from_key` and `to_key`). @@ -797,7 +846,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): and to_key.is_before_or_eq(from_key) ): # Token selection matches what we do below if there are no rows - return [], to_key if to_key else from_key + return [], to_key if to_key else from_key, False # Or vice-versa, if we're looking backwards and our `from_key` is already before # our `to_key`. 
elif ( @@ -806,7 +855,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): and from_key.is_before_or_eq(to_key) ): # Token selection matches what we do below if there are no rows - return [], to_key if to_key else from_key + return [], to_key if to_key else from_key, False # We can do a quick sanity check to see if any events have been sent in the room # since the earlier token. @@ -825,7 +874,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if not has_changed: # Token selection matches what we do below if there are no rows - return [], to_key if to_key else from_key + return [], to_key if to_key else from_key, False order, from_bound, to_bound = generate_pagination_bounds( direction, from_key, to_key @@ -841,7 +890,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): engine=self.database_engine, ) - def f(txn: LoggingTransaction) -> List[_EventDictReturn]: + def f(txn: LoggingTransaction) -> Tuple[List[_EventDictReturn], bool]: sql = f""" SELECT event_id, instance_name, stream_ordering FROM events @@ -853,9 +902,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): """ txn.execute(sql, (room_id, 2 * limit)) + # Get all the rows and check if we hit the limit. 
+ fetched_rows = txn.fetchall() + limited = len(fetched_rows) >= 2 * limit + rows = [ _EventDictReturn(event_id, None, stream_ordering) - for event_id, instance_name, stream_ordering in txn + for event_id, instance_name, stream_ordering in fetched_rows if _filter_results_by_stream( lower_token=( to_key if direction == Direction.BACKWARDS else from_key @@ -866,10 +919,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): instance_name=instance_name, stream_ordering=stream_ordering, ) - ][:limit] - return rows + ] + + if len(rows) > limit: + limited = True - rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f) + rows = rows[:limit] + return rows, limited + + rows, limited = await self.db_pool.runInteraction( + "get_room_events_stream_for_room", f + ) ret = await self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True @@ -886,7 +946,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): # `_paginate_room_events_by_topological_ordering_txn(...)`) next_key = to_key if to_key else from_key - return ret, next_key + return ret, next_key, limited @trace async def get_current_state_delta_membership_changes_for_user( @@ -927,7 +987,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: All membership changes to the current state in the token range. Events are sorted by `stream_ordering` ascending. + + `event_id`/`sender` can be `None` when the server leaves a room (meaning + everyone locally left) or a state reset which removed the person from the + room. We can't tell the difference between the two cases with what's + available in the `current_state_delta_stream` table. To actually check for a + state reset, you need to check if a membership still exists in the room. """ + + assert from_key.topological is None + assert to_key.topological is None + # Start by ruling out cases where a DB query is not necessary. 
if from_key == to_key: return [] @@ -1038,6 +1108,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): membership=( membership if membership is not None else Membership.LEAVE ), + # This will also be null for the same reasons if `s.event_id = null` sender=sender, # Prev event prev_event_id=prev_event_id, @@ -1072,6 +1143,203 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if membership_change.room_id not in room_ids_to_exclude ] + @trace + async def get_sliding_sync_membership_changes( + self, + user_id: str, + from_key: RoomStreamToken, + to_key: RoomStreamToken, + excluded_room_ids: Optional[AbstractSet[str]] = None, + ) -> Dict[str, RoomsForUserStateReset]: + """ + Fetch membership events that result in a meaningful membership change for a + given user. + + A meaningful membership change is one where the `membership` value actually + changes. This means membership changes from `join` to `join` (like a display + name change) will be filtered out since they result in no meaningful change. + + Note: This function only works with "live" tokens with `stream_ordering` only. + + We're looking for membership changes in the token range (> `from_key` and <= + `to_key`). + + Args: + user_id: The user ID to fetch membership events for. + from_key: The point in the stream to sync from (fetching events > this point). + to_key: The token to fetch rooms up to (fetching events <= this point). + excluded_room_ids: Optional list of room IDs to exclude from the results. + + Returns: + All meaningful membership changes to the current state in the token range. + Events are sorted by `stream_ordering` ascending. + + `event_id`/`sender` can be `None` when the server leaves a room (meaning + everyone locally left) or a state reset which removed the person from the + room. We can't tell the difference between the two cases with what's + available in the `current_state_delta_stream` table. 
To actually check for a + state reset, you need to check if a membership still exists in the room. + """ + + assert from_key.topological is None + assert to_key.topological is None + + # Start by ruling out cases where a DB query is not necessary. + if from_key == to_key: + return {} + + if from_key: + has_changed = self._membership_stream_cache.has_entity_changed( + user_id, int(from_key.stream) + ) + if not has_changed: + return {} + + room_ids_to_exclude: AbstractSet[str] = set() + if excluded_room_ids is not None: + room_ids_to_exclude = excluded_room_ids + + def f(txn: LoggingTransaction) -> Dict[str, RoomsForUserStateReset]: + # To handle tokens with a non-empty instance_map we fetch more + # results than necessary and then filter down + min_from_id = from_key.stream + max_to_id = to_key.get_max_stream_pos() + + # This query looks at membership changes in + # `sliding_sync_membership_snapshots` which will not include users + # that were state reset out of rooms; so we need to look for that + # case in `current_state_delta_stream`. + sql = """ + SELECT + room_id, + membership_event_id, + event_instance_name, + event_stream_ordering, + membership, + sender, + prev_membership, + room_version + FROM + ( + SELECT + s.room_id, + s.membership_event_id, + s.event_instance_name, + s.event_stream_ordering, + s.membership, + s.sender, + m_prev.membership AS prev_membership + FROM sliding_sync_membership_snapshots as s + LEFT JOIN event_edges AS e ON e.event_id = s.membership_event_id + LEFT JOIN room_memberships AS m_prev ON m_prev.event_id = e.prev_event_id + WHERE s.user_id = ? 
+ + UNION ALL + + SELECT + s.room_id, + e.event_id, + s.instance_name, + s.stream_id, + m.membership, + e.sender, + m_prev.membership AS prev_membership + FROM current_state_delta_stream AS s + LEFT JOIN events AS e ON e.event_id = s.event_id + LEFT JOIN room_memberships AS m ON m.event_id = s.event_id + LEFT JOIN room_memberships AS m_prev ON m_prev.event_id = s.prev_event_id + WHERE + s.type = ? + AND s.state_key = ? + ) AS c + INNER JOIN rooms USING (room_id) + WHERE event_stream_ordering > ? AND event_stream_ordering <= ? + ORDER BY event_stream_ordering ASC + """ + + txn.execute( + sql, + (user_id, EventTypes.Member, user_id, min_from_id, max_to_id), + ) + + membership_changes: Dict[str, RoomsForUserStateReset] = {} + for ( + room_id, + membership_event_id, + event_instance_name, + event_stream_ordering, + membership, + sender, + prev_membership, + room_version_id, + ) in txn: + assert room_id is not None + assert event_stream_ordering is not None + + if room_id in room_ids_to_exclude: + continue + + if _filter_results_by_stream( + from_key, + to_key, + event_instance_name, + event_stream_ordering, + ): + # When the server leaves a room, it will insert new rows into the + # `current_state_delta_stream` table with `event_id = null` for all + # current state. This means we might already have a row for the + # leave event and then another for the same leave where the + # `event_id=null` but the `prev_event_id` is pointing back at the + # earlier leave event. We don't want to report the leave, if we + # already have a leave event. + if ( + membership_event_id is None + and prev_membership == Membership.LEAVE + ): + continue + + if membership_event_id is None and room_id in membership_changes: + # SUSPICIOUS: if we join a room and get state reset out of it + # in the same queried window, + # won't this ignore the 'state reset out of it' part? 
+ continue + + # When `s.event_id = null`, we won't be able to get respective + # `room_membership` but can assume the user has left the room + # because this only happens when the server leaves a room + # (meaning everyone locally left) or a state reset which removed + # the person from the room. + membership = ( + membership if membership is not None else Membership.LEAVE + ) + + if membership == prev_membership: + # If `membership` and `prev_membership` are the same then this + # is not a meaningful change so we can skip it. + # An example of this happening is when the user changes their display name. + continue + + membership_change = RoomsForUserStateReset( + room_id=room_id, + sender=sender, + membership=membership, + event_id=membership_event_id, + event_pos=PersistedEventPosition( + event_instance_name, event_stream_ordering + ), + room_version_id=room_version_id, + ) + + membership_changes[room_id] = membership_change + + return membership_changes + + membership_changes = await self.db_pool.runInteraction( + "get_sliding_sync_membership_changes", f + ) + + return membership_changes + @cancellable async def get_membership_changes_for_user( self, @@ -1121,9 +1389,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): AND e.stream_ordering > ? AND e.stream_ordering <= ? 
%s ORDER BY e.stream_ordering ASC - """ % ( - ignore_room_clause, - ) + """ % (ignore_room_clause,) txn.execute(sql, args) @@ -1192,7 +1458,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if limit == 0: return [], end_token - rows, token = await self.db_pool.runInteraction( + rows, token, _ = await self.db_pool.runInteraction( "get_recent_event_ids_for_room", self._paginate_room_events_by_topological_ordering_txn, room_id, @@ -1263,12 +1529,76 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return None + async def get_last_event_pos_in_room( + self, + room_id: str, + event_types: Optional[StrCollection] = None, + ) -> Optional[Tuple[str, PersistedEventPosition]]: + """ + Returns the ID and event position of the last event in a room. + + Based on `get_last_event_pos_in_room_before_stream_ordering(...)` + + Args: + room_id + event_types: Optional allowlist of event types to filter by + + Returns: + The ID of the most recent event and its position, or None if there are no + events in the room that match the given event types. + """ + + def _get_last_event_pos_in_room_txn( + txn: LoggingTransaction, + ) -> Optional[Tuple[str, PersistedEventPosition]]: + event_type_clause = "" + event_type_args: List[str] = [] + if event_types is not None and len(event_types) > 0: + event_type_clause, event_type_args = make_in_list_sql_clause( + txn.database_engine, "type", event_types + ) + event_type_clause = f"AND {event_type_clause}" + + sql = f""" + SELECT event_id, stream_ordering, instance_name + FROM events + LEFT JOIN rejections USING (event_id) + WHERE room_id = ? 
+ {event_type_clause} + AND NOT outlier + AND rejections.event_id IS NULL + ORDER BY stream_ordering DESC + LIMIT 1 + """ + + txn.execute( + sql, + [room_id] + event_type_args, + ) + + row = cast(Optional[Tuple[str, int, str]], txn.fetchone()) + if row is not None: + event_id, stream_ordering, instance_name = row + + return event_id, PersistedEventPosition( + # If instance_name is null we default to "master" + instance_name or "master", + stream_ordering, + ) + + return None + + return await self.db_pool.runInteraction( + "get_last_event_pos_in_room", + _get_last_event_pos_in_room_txn, + ) + @trace async def get_last_event_pos_in_room_before_stream_ordering( self, room_id: str, end_token: RoomStreamToken, - event_types: Optional[Collection[str]] = None, + event_types: Optional[StrCollection] = None, ) -> Optional[Tuple[str, PersistedEventPosition]]: """ Returns the ID and event position of the last event in a room at or before a @@ -1381,8 +1711,56 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): rooms """ + # First we just get the latest positions for the room, as the vast + # majority of them will be before the given end token anyway. By doing + # this we can cache most rooms. + uncapped_results = await self._bulk_get_max_event_pos(room_ids) + + # Check that the stream position for the rooms are from before the + # minimum position of the token. If not then we need to fetch more + # rows. + results: Dict[str, int] = {} + recheck_rooms: Set[str] = set() min_token = end_token.stream - max_token = end_token.get_max_stream_pos() + for room_id, stream in uncapped_results.items(): + if stream is None: + # Despite the function not directly setting None, the cache can! + # See: https://github.com/element-hq/synapse/issues/17726 + continue + if stream <= min_token: + results[room_id] = stream + else: + recheck_rooms.add(room_id) + + if not recheck_rooms: + return results + + # There shouldn't be many rooms that we need to recheck, so we do them + # one-by-one. 
+ for room_id in recheck_rooms: + result = await self.get_last_event_pos_in_room_before_stream_ordering( + room_id, end_token + ) + if result is not None: + results[room_id] = result[1].stream + + return results + + @cached() + async def _get_max_event_pos(self, room_id: str) -> int: + raise NotImplementedError() + + @cachedList(cached_method_name="_get_max_event_pos", list_name="room_ids") + async def _bulk_get_max_event_pos( + self, room_ids: StrCollection + ) -> Mapping[str, Optional[int]]: + """Fetch the max position of a persisted event in the room.""" + + # We need to be careful not to return positions ahead of the current + # positions, so we get the current token now and cap our queries to it. + now_token = self.get_room_max_token() + max_pos = now_token.get_max_stream_pos() + results: Dict[str, int] = {} # First, we check for the rooms in the stream change cache to see if we @@ -1390,31 +1768,32 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): missing_room_ids: Set[str] = set() for room_id in room_ids: stream_pos = self._events_stream_cache.get_max_pos_of_last_change(room_id) - if stream_pos and stream_pos <= min_token: + if stream_pos is not None: results[room_id] = stream_pos else: missing_room_ids.add(room_id) + if not missing_room_ids: + return results + # Next, we query the stream position from the DB. At first we fetch all # positions less than the *max* stream pos in the token, then filter # them down. We do this as a) this is a cheaper query, and b) the vast # majority of rooms will have a latest token from before the min stream # pos. - def bulk_get_last_event_pos_txn( - txn: LoggingTransaction, batch_room_ids: StrCollection + def bulk_get_max_event_pos_fallback_txn( + txn: LoggingTransaction, batched_room_ids: StrCollection ) -> Dict[str, int]: - # This query fetches the latest stream position in the rooms before - # the given max position. 
clause, args = make_in_list_sql_clause( - self.database_engine, "room_id", batch_room_ids + self.database_engine, "room_id", batched_room_ids ) sql = f""" SELECT room_id, ( SELECT stream_ordering FROM events AS e LEFT JOIN rejections USING (event_id) WHERE e.room_id = r.room_id - AND stream_ordering <= ? + AND e.stream_ordering <= ? AND NOT outlier AND rejection_reason IS NULL ORDER BY stream_ordering DESC @@ -1423,72 +1802,55 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): FROM rooms AS r WHERE {clause} """ - txn.execute(sql, [max_token] + args) + txn.execute(sql, [max_pos] + args) return {row[0]: row[1] for row in txn} - recheck_rooms: Set[str] = set() - for batched in batch_iter(missing_room_ids, 1000): - result = await self.db_pool.runInteraction( - "bulk_get_last_event_pos_in_room_before_stream_ordering", - bulk_get_last_event_pos_txn, - batched, - ) - - # Check that the stream position for the rooms are from before the - # minimum position of the token. If not then we need to fetch more - # rows. - for room_id, stream in result.items(): - if stream <= min_token: - results[room_id] = stream - else: - recheck_rooms.add(room_id) - - if not recheck_rooms: - return results - - # For the remaining rooms we need to fetch all rows between the min and - # max stream positions in the end token, and filter out the rows that - # are after the end token. - # - # This query should be fast as the range between the min and max should - # be small. - - def bulk_get_last_event_pos_recheck_txn( - txn: LoggingTransaction, batch_room_ids: StrCollection + # It's easier to look at the `sliding_sync_joined_rooms` table and avoid all of + # the joins and sub-queries. 
+ def bulk_get_max_event_pos_from_sliding_sync_tables_txn( + txn: LoggingTransaction, batched_room_ids: StrCollection ) -> Dict[str, int]: clause, args = make_in_list_sql_clause( - self.database_engine, "room_id", batch_room_ids + self.database_engine, "room_id", batched_room_ids ) sql = f""" - SELECT room_id, instance_name, stream_ordering - FROM events - WHERE ? < stream_ordering AND stream_ordering <= ? - AND NOT outlier - AND rejection_reason IS NULL - AND {clause} - ORDER BY stream_ordering ASC + SELECT room_id, event_stream_ordering + FROM sliding_sync_joined_rooms + WHERE {clause} + ORDER BY event_stream_ordering DESC """ - txn.execute(sql, [min_token, max_token] + args) - - # We take the max stream ordering that is less than the token. Since - # we ordered by stream ordering we just need to iterate through and - # take the last matching stream ordering. - txn_results: Dict[str, int] = {} - for row in txn: - room_id = row[0] - event_pos = PersistedEventPosition(row[1], row[2]) - if not event_pos.persisted_after(end_token): - txn_results[room_id] = event_pos.stream - - return txn_results - - for batched in batch_iter(recheck_rooms, 1000): - recheck_result = await self.db_pool.runInteraction( - "bulk_get_last_event_pos_in_room_before_stream_ordering_recheck", - bulk_get_last_event_pos_recheck_txn, - batched, + txn.execute(sql, args) + return {row[0]: row[1] for row in txn} + + recheck_rooms: Set[str] = set() + for batched in batch_iter(room_ids, 1000): + if await self.have_finished_sliding_sync_background_jobs(): + batch_results = await self.db_pool.runInteraction( + "bulk_get_max_event_pos_from_sliding_sync_tables_txn", + bulk_get_max_event_pos_from_sliding_sync_tables_txn, + batched, + ) + else: + batch_results = await self.db_pool.runInteraction( + "bulk_get_max_event_pos_fallback_txn", + bulk_get_max_event_pos_fallback_txn, + batched, + ) + for room_id, stream_ordering in batch_results.items(): + if stream_ordering <= now_token.stream: + results[room_id] = 
stream_ordering + else: + recheck_rooms.add(room_id) + + # We now need to handle rooms where the above query returned a stream + # position that was potentially too new. This should happen very rarely + # so we just query the rooms one-by-one + for room_id in recheck_rooms: + result = await self.get_last_event_pos_in_room_before_stream_ordering( + room_id, now_token ) - results.update(recheck_result) + if result is not None: + results[room_id] = result[1].stream return results @@ -1680,15 +2042,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): dict """ - stream_ordering, topological_ordering = cast( - Tuple[int, int], - self.db_pool.simple_select_one_txn( - txn, - "events", - keyvalues={"event_id": event_id, "room_id": room_id}, - retcols=["stream_ordering", "topological_ordering"], - ), + row = self.db_pool.simple_select_one_txn( + txn, + "events", + keyvalues={"event_id": event_id, "room_id": room_id}, + retcols=("stream_ordering", "topological_ordering"), ) + stream_ordering = int(row[0]) + topological_ordering = int(row[1]) # Paginating backwards includes the event at the token, but paginating # forward doesn't. 
@@ -1700,7 +2061,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): topological=topological_ordering, stream=stream_ordering ) - rows, start_token = self._paginate_room_events_by_topological_ordering_txn( + rows, start_token, _ = self._paginate_room_events_by_topological_ordering_txn( txn, room_id, before_token, @@ -1710,7 +2071,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) events_before = [r.event_id for r in rows] - rows, end_token = self._paginate_room_events_by_topological_ordering_txn( + rows, end_token, _ = self._paginate_room_events_by_topological_ordering_txn( txn, room_id, after_token, @@ -1882,7 +2243,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): direction: Direction = Direction.BACKWARDS, limit: int = 0, event_filter: Optional[Filter] = None, - ) -> Tuple[List[_EventDictReturn], RoomStreamToken]: + ) -> Tuple[List[_EventDictReturn], RoomStreamToken, bool]: """Returns list of events before or after a given token. Args: @@ -1897,10 +2258,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): those that match the filter. Returns: - A list of _EventDictReturn and a token that points to the end of the - result set. If no events are returned then the end of the stream has - been reached (i.e. there are no events between `from_token` and - `to_token`), or `limit` is zero. + A list of _EventDictReturn, a token that points to the end of the + result set, and a boolean to indicate if there were more events but + we hit the limit. If no events are returned then the end of the + stream has been reached (i.e. there are no events between + `from_token` and `to_token`), or `limit` is zero. """ # We can bail early if we're looking forwards, and our `to_key` is already # before our `from_token`. 
@@ -1910,7 +2272,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): and to_token.is_before_or_eq(from_token) ): # Token selection matches what we do below if there are no rows - return [], to_token if to_token else from_token + return [], to_token if to_token else from_token, False # Or vice-versa, if we're looking backwards and our `from_token` is already before # our `to_token`. elif ( @@ -1919,7 +2281,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): and from_token.is_before_or_eq(to_token) ): # Token selection matches what we do below if there are no rows - return [], to_token if to_token else from_token + return [], to_token if to_token else from_token, False args: List[Any] = [room_id] @@ -1942,6 +2304,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): args.extend(filter_args) # We fetch more events as we'll filter the result set + requested_limit = int(limit) * 2 args.append(int(limit) * 2) select_keywords = "SELECT" @@ -2006,10 +2369,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): } txn.execute(sql, args) + # Get all the rows and check if we hit the limit. + fetched_rows = txn.fetchall() + limited = len(fetched_rows) >= requested_limit + # Filter the result set. 
rows = [ _EventDictReturn(event_id, topological_ordering, stream_ordering) - for event_id, instance_name, topological_ordering, stream_ordering in txn + for event_id, instance_name, topological_ordering, stream_ordering in fetched_rows if _filter_results( lower_token=( to_token if direction == Direction.BACKWARDS else from_token @@ -2021,7 +2388,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): topological_ordering=topological_ordering, stream_ordering=stream_ordering, ) - ][:limit] + ] + + if len(rows) > limit: + limited = True + + rows = rows[:limit] if rows: assert rows[-1].topological_ordering is not None @@ -2032,7 +2404,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): # TODO (erikj): We should work out what to do here instead. next_token = to_token if to_token else from_token - return rows, next_token + return rows, next_token, limited @trace @tag_args @@ -2045,7 +2417,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): direction: Direction = Direction.BACKWARDS, limit: int = 0, event_filter: Optional[Filter] = None, - ) -> Tuple[List[EventBase], RoomStreamToken]: + ) -> Tuple[List[EventBase], RoomStreamToken, bool]: """ Paginate events by `topological_ordering` (tie-break with `stream_ordering`) in the room from the `from_key` in the given `direction` to the `to_key` or @@ -2062,8 +2434,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): event_filter: If provided filters the events to those that match the filter. Returns: - The results as a list of events and a token that points to the end - of the result set. If no events are returned then the end of the + The results as a list of events, a token that points to the end of + the result set, and a boolean to indicate if there were more events + but we hit the limit. If no events are returned then the end of the stream has been reached (i.e. there are no events between `from_key` and `to_key`). 
@@ -2087,7 +2460,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ): # Token selection matches what we do in `_paginate_room_events_txn` if there # are no rows - return [], to_key if to_key else from_key + return [], to_key if to_key else from_key, False # Or vice-versa, if we're looking backwards and our `from_key` is already before # our `to_key`. elif ( @@ -2097,9 +2470,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ): # Token selection matches what we do in `_paginate_room_events_txn` if there # are no rows - return [], to_key if to_key else from_key + return [], to_key if to_key else from_key, False - rows, token = await self.db_pool.runInteraction( + rows, token, limited = await self.db_pool.runInteraction( "paginate_room_events_by_topological_ordering", self._paginate_room_events_by_topological_ordering_txn, room_id, @@ -2114,7 +2487,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): [r.event_id for r in rows], get_prev_content=True ) - return events, token + return events, token, limited @cached() async def get_id_for_instance(self, instance_name: str) -> int: diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index b5af294384..97b190bccc 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py
@@ -158,9 +158,56 @@ class TagsWorkerStore(AccountDataWorkerStore): return results + async def has_tags_changed_for_room( + self, + # Since there are multiple arguments with the same type, force keyword arguments + # so people don't accidentally swap the order + *, + user_id: str, + room_id: str, + from_stream_id: int, + to_stream_id: int, + ) -> bool: + """Check if the user's tags for a room have been updated in the token range + + (> `from_stream_id` and <= `to_stream_id`) + + Args: + user_id: The user to get tags for + room_id: The room to get tags for + from_stream_id: The point in the stream to fetch from + to_stream_id: The point in the stream to fetch to + + Returns: + True if the user's tags for the room changed within the token range, + False otherwise. + """ + + # Shortcut if no room has changed for the user + changed = self._account_data_stream_cache.has_entity_changed( + user_id, int(from_stream_id) + ) + if not changed: + return False + + last_change_position_for_room = await self.db_pool.simple_select_one_onecol( + table="room_tags_revisions", + keyvalues={"user_id": user_id, "room_id": room_id}, + retcol="stream_id", + allow_none=True, + ) + + if last_change_position_for_room is None: + return False + + return ( + last_change_position_for_room > from_stream_id + and last_change_position_for_room <= to_stream_id + ) + + @cached(num_args=2, tree=True) async def get_tags_for_room( self, user_id: str, room_id: str - ) -> Dict[str, JsonDict]: + ) -> Mapping[str, JsonMapping]: """Get all the tags for the given room Args: @@ -182,7 +229,7 @@ class TagsWorkerStore(AccountDataWorkerStore): return {tag: db_to_json(content) for tag, content in rows} async def add_tag_to_room( - self, user_id: str, room_id: str, tag: str, content: JsonDict + self, user_id: str, room_id: str, tag: str, content: JsonMapping ) -> int: """Add a tag to a room for a user.
@@ -213,6 +260,7 @@ class TagsWorkerStore(AccountDataWorkerStore): await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) + self.get_tags_for_room.invalidate((user_id, room_id)) return self._account_data_id_gen.get_current_token() @@ -226,10 +274,7 @@ class TagsWorkerStore(AccountDataWorkerStore): assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) def remove_tag_txn(txn: LoggingTransaction, next_id: int) -> None: - sql = ( - "DELETE FROM room_tags " - " WHERE user_id = ? AND room_id = ? AND tag = ?" - ) + sql = "DELETE FROM room_tags WHERE user_id = ? AND room_id = ? AND tag = ?" txn.execute(sql, (user_id, room_id, tag)) self._update_revision_txn(txn, user_id, room_id, next_id) @@ -237,6 +282,7 @@ class TagsWorkerStore(AccountDataWorkerStore): await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) + self.get_tags_for_room.invalidate((user_id, room_id)) return self._account_data_id_gen.get_current_token() @@ -290,9 +336,19 @@ class TagsWorkerStore(AccountDataWorkerStore): rows: Iterable[Any], ) -> None: if stream_name == AccountDataStream.NAME: - for row in rows: + # Cast is safe because the `AccountDataStream` should only be giving us + # `AccountDataStreamRow` + account_data_stream_rows: List[AccountDataStream.AccountDataStreamRow] = ( + cast(List[AccountDataStream.AccountDataStreamRow], rows) + ) + + for row in account_data_stream_rows: if row.data_type == AccountDataTypes.TAG: self.get_tags_for_user.invalidate((row.user_id,)) + if row.room_id: + self.get_tags_for_room.invalidate((row.user_id, row.room_id)) + else: + self.get_tags_for_room.invalidate((row.user_id,)) self._account_data_stream_cache.entity_has_changed( row.user_id, token ) diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 770802483c..bfc324b80d 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py
@@ -86,10 +86,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): @wrap_as_background_process("cleanup_transactions") async def _cleanup_transactions(self) -> None: now = self._clock.time_msec() - month_ago = now - 30 * 24 * 60 * 60 * 1000 + day_ago = now - 24 * 60 * 60 * 1000 def _cleanup_transactions_txn(txn: LoggingTransaction) -> None: - txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) + txn.execute("DELETE FROM received_transactions WHERE ts < ?", (day_ago,)) await self.db_pool.runInteraction( "_cleanup_transactions", _cleanup_transactions_txn diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 6e18f714d7..31a8ce6666 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py
@@ -31,6 +31,7 @@ from typing import ( Sequence, Set, Tuple, + TypedDict, cast, ) @@ -42,10 +43,9 @@ try: USE_ICU = True except ModuleNotFoundError: + # except ModuleNotFoundError: USE_ICU = False -from typing_extensions import TypedDict - from synapse.api.errors import StoreError from synapse.util.stringutils import non_null_str_or_none @@ -224,9 +224,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): SELECT room_id, events FROM %s ORDER BY events DESC LIMIT 250 - """ % ( - TEMP_TABLE + "_rooms", - ) + """ % (TEMP_TABLE + "_rooms",) txn.execute(sql) rooms_to_work_on = cast(List[Tuple[str, int]], txn.fetchall()) @@ -585,9 +583,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): retry_counter: number of failures in refreshing the profile so far. Used for exponential backoff calculations. """ - assert not self.hs.is_mine_id( - user_id - ), "Can't mark a local user as a stale remote user." + assert not self.hs.is_mine_id(user_id), ( + "Can't mark a local user as a stale remote user." + ) server_name = UserID.from_string(user_id).domain @@ -1040,11 +1038,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): } """ + join_args: Tuple[str, ...] = (user_id,) + if self.hs.config.userdirectory.user_directory_search_all_users: - join_args = (user_id,) where_clause = "user_id != ?" 
else: - join_args = (user_id,) where_clause = """ ( EXISTS (select 1 from users_in_public_rooms WHERE user_id = t.user_id) @@ -1058,6 +1056,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): if not show_locked_users: where_clause += " AND (u.locked IS NULL OR u.locked = FALSE)" + # Adjust the JOIN type based on the exclude_remote_users flag (the users + # table only contains local users so an inner join is a good way to + # exclude remote users) + if self.hs.config.userdirectory.user_directory_exclude_remote_users: + join_type = "JOIN" + else: + join_type = "LEFT JOIN" + # We allow manipulating the ranking algorithm by injecting statements # based on config options. additional_ordering_statements = [] @@ -1089,7 +1095,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): SELECT d.user_id AS user_id, display_name, avatar_url FROM matching_users as t INNER JOIN user_directory AS d USING (user_id) - LEFT JOIN users AS u ON t.user_id = u.name + %(join_type)s users AS u ON t.user_id = u.name WHERE %(where_clause)s ORDER BY @@ -1118,6 +1124,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): """ % { "where_clause": where_clause, "order_case_statements": " ".join(additional_ordering_statements), + "join_type": join_type, } args = ( (full_query,) @@ -1145,7 +1152,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): SELECT d.user_id AS user_id, display_name, avatar_url FROM user_directory_search as t INNER JOIN user_directory AS d USING (user_id) - LEFT JOIN users AS u ON t.user_id = u.name + %(join_type)s users AS u ON t.user_id = u.name WHERE %(where_clause)s AND value MATCH ?
@@ -1158,6 +1165,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): """ % { "where_clause": where_clause, "order_statements": " ".join(additional_ordering_statements), + "join_type": join_type, } args = join_args + (search_query,) + ordering_arguments + (limit + 1,) else: @@ -1240,7 +1248,13 @@ def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]: search_term = _filter_text_for_index(search_term) escaped_words = [] - for word in _parse_words(search_term): + for index, word in enumerate(_parse_words(search_term)): + if index >= 10: + # We limit how many terms we include, as otherwise it can use + # excessive database time if people accidentally search for large + # strings. + break + # Postgres tsvector and tsquery quoting rules: # words potentially containing punctuation should be quoted # and then existing quotes and backslashes should be doubled diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index ea7d8199a7..5b594fe8dd 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py
@@ -20,7 +20,15 @@ # import logging -from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Dict, + List, + Mapping, + Optional, + Tuple, + Union, +) from synapse.logging.opentracing import tag_args, trace from synapse.storage._base import SQLBaseStore @@ -112,8 +120,8 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): Returns: Map from state_group to a StateMap at that point. """ - - state_filter = state_filter or StateFilter.all() + if state_filter is None: + state_filter = StateFilter.all() results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups} @@ -388,8 +396,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): return True, count txn.execute( - "SELECT state_group FROM state_group_edges" - " WHERE state_group = ?", + "SELECT state_group FROM state_group_edges WHERE state_group = ?", (state_group,), ) diff --git a/synapse/storage/databases/state/deletion.py b/synapse/storage/databases/state/deletion.py new file mode 100644
index 0000000000..f77c46f6ae --- /dev/null +++ b/synapse/storage/databases/state/deletion.py
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#


import contextlib
from typing import (
    TYPE_CHECKING,
    AbstractSet,
    AsyncIterator,
    Collection,
    Mapping,
    Optional,
    Set,
    Tuple,
)

from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.storage.database import (
    DatabasePool,
    LoggingDatabaseConnection,
    LoggingTransaction,
    make_in_list_sql_clause,
)
from synapse.storage.engines import PostgresEngine
from synapse.util.stringutils import shortstr

if TYPE_CHECKING:
    from synapse.server import HomeServer


class StateDeletionDataStore:
    """Manages deletion of state groups in a safe manner.

    Deleting state groups is challenging as before we actually delete them we
    need to ensure that there are no in-flight events that refer to the state
    groups that we want to delete.

    To handle this, we take two approaches. First, before we persist any event
    we ensure that the state group still exists and mark in the
    `state_groups_persisting` table that the state group is about to be used.
    (Note that we have to have the extra table here as state groups and events
    can be in different databases, and thus we can't check for the existence of
    state groups in the persist event transaction). Once the event has been
    persisted, we can remove the row from `state_groups_persisting`. So long as
    we check that table before deleting state groups, we can ensure that we
    never persist events that reference deleted state groups, maintaining
    database integrity.

    However, we want to avoid throwing exceptions so deep in the process of
    persisting events. So instead of deleting state groups immediately, we mark
    them as pending/proposed for deletion and wait for a certain amount of time
    before performing the deletion. When we come to handle new events that
    reference state groups, we check if they are pending deletion and bump the
    time for when they'll be deleted (to give a chance for the event to be
    persisted, or not).

    When deleting, we need to check that state groups remain unreferenced. There
    is a race here where we a) fetch state groups that are ready for deletion,
    b) check they're unreferenced, c) the state group becomes referenced but
    then gets marked as pending deletion again, d) during the deletion
    transaction we recheck `state_groups_pending_deletion` table again and see
    that it exists and so continue with the deletion. To prevent this from
    happening we add a `sequence_number` column to
    `state_groups_pending_deletion`, and during deletion we ensure that for a
    state group we're about to delete that the sequence number doesn't change
    between steps (a) and (d). So long as we always bump the sequence number
    whenever an event may become used the race can never happen.
    """

    # How long to wait before we delete state groups. This should be long enough
    # for any in-flight events to be persisted. If events take longer to persist
    # and any of the state groups they reference have been deleted, then the
    # event will fail to persist (as well as any event in the same batch).
    DELAY_BEFORE_DELETION_MS = 10 * 60 * 1000

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        self._clock = hs.get_clock()
        self.db_pool = database
        self._instance_name = hs.get_instance_name()

        # Drop any `state_groups_persisting` rows this instance left behind.
        with db_conn.cursor(txn_name="_clear_existing_persising") as txn:
            self._clear_existing_persising(txn)

    def _clear_existing_persising(self, txn: LoggingTransaction) -> None:
        """On startup we clear any entries in `state_groups_persisting` that
        match our instance name, in case of a previous unclean shutdown"""

        self.db_pool.simple_delete_txn(
            txn,
            table="state_groups_persisting",
            keyvalues={"instance_name": self._instance_name},
        )

    async def check_state_groups_and_bump_deletion(
        self, state_groups: AbstractSet[int]
    ) -> Collection[int]:
        """Checks to make sure that the state groups haven't been deleted, and
        if they're pending deletion we delay it (allowing time for any event
        that will use them to finish persisting).

        Returns:
            The state groups that are missing, if any.
        """

        return await self.db_pool.runInteraction(
            "check_state_groups_and_bump_deletion",
            self._check_state_groups_and_bump_deletion_txn,
            state_groups,
            # We don't need to lock if we're just doing a quick check, as the
            # lock doesn't prevent any races here.
            lock=False,
        )

    def _check_state_groups_and_bump_deletion_txn(
        self, txn: LoggingTransaction, state_groups: AbstractSet[int], lock: bool = True
    ) -> Collection[int]:
        """Checks to make sure that the state groups haven't been deleted, and
        if they're pending deletion we delay it (allowing time for any event
        that will use them to finish persisting).

        The `lock` flag sets if we should lock the `state_group` rows we're
        checking, which we should do when storing new groups.

        Returns:
            The state groups that are missing, if any.
        """

        existing_state_groups = self._get_existing_groups_with_lock(
            txn, state_groups, lock=lock
        )

        # Only groups that still exist can have a pending deletion to delay.
        self._bump_deletion_txn(txn, existing_state_groups)

        missing_state_groups = state_groups - existing_state_groups
        if missing_state_groups:
            return missing_state_groups

        return ()

    def _bump_deletion_txn(
        self, txn: LoggingTransaction, state_groups: Collection[int]
    ) -> None:
        """Update any pending deletions of the state group that they may now be
        referenced.

        For every given state group that has a row in
        `state_groups_pending_deletion`, resets `insertion_ts` to now and
        obtains a fresh `sequence_number`. State groups without a row are left
        untouched.
        """

        if not state_groups:
            return

        now = self._clock.time_msec()
        if isinstance(self.db_pool.engine, PostgresEngine):
            clause, args = make_in_list_sql_clause(
                self.db_pool.engine, "state_group", state_groups
            )
            sql = f"""
                UPDATE state_groups_pending_deletion
                SET sequence_number = DEFAULT, insertion_ts = ?
                WHERE {clause}
            """
            # The `insertion_ts = ?` placeholder precedes the IN-list ones.
            args.insert(0, now)
            txn.execute(sql, args)
        else:
            # On other engines we get a new sequence number by deleting the
            # existing rows and re-inserting them.
            rows = self.db_pool.simple_select_many_txn(
                txn,
                table="state_groups_pending_deletion",
                column="state_group",
                iterable=state_groups,
                keyvalues={},
                retcols=("state_group",),
            )
            if not rows:
                return

            state_groups_to_update = [state_group for (state_group,) in rows]

            self.db_pool.simple_delete_many_txn(
                txn,
                table="state_groups_pending_deletion",
                column="state_group",
                values=state_groups_to_update,
                keyvalues={},
            )
            self.db_pool.simple_insert_many_txn(
                txn,
                table="state_groups_pending_deletion",
                keys=("state_group", "insertion_ts"),
                values=[(state_group, now) for state_group in state_groups_to_update],
            )

    def _get_existing_groups_with_lock(
        self, txn: LoggingTransaction, state_groups: Collection[int], lock: bool = True
    ) -> AbstractSet[int]:
        """Return which of the given state groups are in the database, and locks
        those rows with `KEY SHARE` to ensure they don't get concurrently
        deleted (if `lock` is true)."""
        clause, args = make_in_list_sql_clause(self.db_pool.engine, "id", state_groups)

        sql = f"""
            SELECT id FROM state_groups
            WHERE {clause}
        """
        if lock and isinstance(self.db_pool.engine, PostgresEngine):
            # On postgres we add a row level lock to the rows to ensure that we
            # conflict with any concurrent DELETEs. A `FOR KEY SHARE` lock will
            # not conflict with other readers.
            sql += """
            FOR KEY SHARE
            """

        txn.execute(sql, args)
        return {state_group for (state_group,) in txn}

    @contextlib.asynccontextmanager
    async def persisting_state_group_references(
        self, event_and_contexts: Collection[Tuple[EventBase, EventContext]]
    ) -> AsyncIterator[None]:
        """Wraps the persistence of the given events and contexts, ensuring that
        any state groups referenced still exist and that they don't get deleted
        during this.

        Raises:
            Exception: if any referenced state group no longer exists (raised
                before the wrapped body runs).
        """

        referenced_state_groups: Set[int] = set()
        for event, ctx in event_and_contexts:
            # Rejected events and outliers are skipped here.
            if ctx.rejected or event.internal_metadata.is_outlier():
                continue

            assert ctx.state_group is not None

            referenced_state_groups.add(ctx.state_group)

            if ctx.state_group_before_event:
                referenced_state_groups.add(ctx.state_group_before_event)

        if not referenced_state_groups:
            # We don't reference any state groups, so nothing to do
            yield
            return

        await self.db_pool.runInteraction(
            "mark_state_groups_as_persisting",
            self._mark_state_groups_as_persisting_txn,
            referenced_state_groups,
        )

        # `error` stays true unless the wrapped body completes normally.
        error = True
        try:
            yield None
            error = False
        finally:
            await self.db_pool.runInteraction(
                "finish_persisting",
                self._finish_persisting_txn,
                referenced_state_groups,
                error=error,
            )

    def _mark_state_groups_as_persisting_txn(
        self, txn: LoggingTransaction, state_groups: Set[int]
    ) -> None:
        """Marks the given state groups as being persisted.

        Raises:
            Exception: if any of the state groups is not in `state_groups`.
        """

        existing_state_groups = self._get_existing_groups_with_lock(txn, state_groups)
        missing_state_groups = state_groups - existing_state_groups
        if missing_state_groups:
            raise Exception(
                f"state groups have been deleted: {shortstr(missing_state_groups)}"
            )

        self.db_pool.simple_insert_many_txn(
            txn,
            table="state_groups_persisting",
            keys=("state_group", "instance_name"),
            values=[(state_group, self._instance_name) for state_group in state_groups],
        )

    def _finish_persisting_txn(
        self, txn: LoggingTransaction, state_groups: Collection[int], error: bool
    ) -> None:
        """Mark the state groups as having finished persistence.

        Always removes this instance's rows for the state groups from
        `state_groups_persisting`.

        If `error` is true the events may or may not have been persisted, so
        instead of clearing the state groups from the pending deletion table we
        bump any pending deletions. Otherwise the state groups are removed from
        the pending deletion table.
        """
        self.db_pool.simple_delete_many_txn(
            txn,
            table="state_groups_persisting",
            column="state_group",
            values=state_groups,
            keyvalues={"instance_name": self._instance_name},
        )

        if error:
            # The state groups may or may not have been persisted, so we need to
            # bump the deletion to ensure we recheck if they have become
            # referenced.
            self._bump_deletion_txn(txn, state_groups)
            return

        self.db_pool.simple_delete_many_batch_txn(
            txn,
            table="state_groups_pending_deletion",
            keys=("state_group",),
            values=[(state_group,) for state_group in state_groups],
        )

    async def mark_state_groups_as_pending_deletion(
        self, state_groups: Collection[int]
    ) -> None:
        """Mark the given state groups as pending deletion.

        If any of the state groups are already pending deletion, then those records are
        left as is.
        """

        await self.db_pool.runInteraction(
            "mark_state_groups_as_pending_deletion",
            self._mark_state_groups_as_pending_deletion_txn,
            state_groups,
        )

    def _mark_state_groups_as_pending_deletion_txn(
        self,
        txn: LoggingTransaction,
        state_groups: Collection[int],
    ) -> None:
        """Implementation of `mark_state_groups_as_pending_deletion`"""

        # `%s` is filled in below with the engine-specific VALUES placeholder.
        sql = """
            INSERT INTO state_groups_pending_deletion (state_group, insertion_ts)
            VALUES %s
            ON CONFLICT (state_group)
            DO NOTHING
        """

        now = self._clock.time_msec()
        rows = [
            (
                state_group,
                now,
            )
            for state_group in state_groups
        ]
        if isinstance(txn.database_engine, PostgresEngine):
            txn.execute_values(sql % ("?",), rows, fetch=False)
        else:
            txn.execute_batch(sql % ("(?, ?)",), rows)

    async def mark_state_groups_as_used(self, state_groups: Collection[int]) -> None:
        """Mark the given state groups as now being referenced"""

        await self.db_pool.simple_delete_many(
            table="state_groups_pending_deletion",
            column="state_group",
            iterable=state_groups,
            keyvalues={},
            desc="mark_state_groups_as_used",
        )

    async def get_pending_deletions(
        self, state_groups: Collection[int]
    ) -> Mapping[int, int]:
        """Get which state groups are pending deletion.

        Returns:
            a mapping from state groups that are pending deletion to their
            sequence number
        """

        rows = await self.db_pool.simple_select_many_batch(
            table="state_groups_pending_deletion",
            column="state_group",
            iterable=state_groups,
            retcols=("state_group", "sequence_number"),
            keyvalues={},
            desc="get_pending_deletions",
        )

        return dict(rows)

    def get_state_groups_ready_for_potential_deletion_txn(
        self,
        txn: LoggingTransaction,
        state_groups_to_sequence_numbers: Mapping[int, int],
    ) -> Collection[int]:
        """Given a set of state groups, return which state groups can
        potentially be deleted.

        The state groups must have been checked to see if they remain
        unreferenced before calling this function.

        Note: This must be called within the same transaction that the state
        groups are deleted.

        Args:
            state_groups_to_sequence_numbers: The state groups, and the sequence
                numbers from before the state groups were checked to see if they
                were unreferenced.

        Returns:
            The subset of state groups that can safely be deleted

        """

        if not state_groups_to_sequence_numbers:
            # Nothing to check; return an empty collection rather than echoing
            # the (empty) input mapping back to the caller.
            return ()

        if isinstance(self.db_pool.engine, PostgresEngine):
            # On postgres we want to lock the rows FOR UPDATE as early as
            # possible to help avoid conflicts.
            clause, args = make_in_list_sql_clause(
                self.db_pool.engine, "id", state_groups_to_sequence_numbers
            )
            sql = f"""
                SELECT id FROM state_groups
                WHERE {clause}
                FOR UPDATE
            """
            txn.execute(sql, args)

        # Check the deletion status in the DB of the given state groups
        clause, args = make_in_list_sql_clause(
            self.db_pool.engine,
            column="state_group",
            iterable=state_groups_to_sequence_numbers,
        )

        sql = f"""
            SELECT state_group, insertion_ts, sequence_number FROM (
                SELECT state_group, insertion_ts, sequence_number FROM state_groups_pending_deletion
                UNION
                SELECT state_group, null, null FROM state_groups_persisting
            ) AS s
            WHERE {clause}
        """

        txn.execute(sql, args)

        # The above query will return potentially two rows per state group (one
        # for each table), so we track which state groups have enough time
        # elapsed and which are not ready to be deleted.
        ready_to_be_deleted = set()
        not_ready_to_be_deleted = set()

        now = self._clock.time_msec()
        for state_group, insertion_ts, sequence_number in txn:
            if insertion_ts is None:
                # A null insertion_ts means that we are currently persisting
                # events that reference the state group, so we don't delete
                # them.
                not_ready_to_be_deleted.add(state_group)
                continue

            # We know this can't be None if insertion_ts is not None
            assert sequence_number is not None

            # Check if the sequence number has changed, if it has then it
            # indicates that the state group may have become referenced since we
            # checked.
            if state_groups_to_sequence_numbers[state_group] != sequence_number:
                not_ready_to_be_deleted.add(state_group)
                continue

            if now - insertion_ts < self.DELAY_BEFORE_DELETION_MS:
                # Not enough time has elapsed to allow us to delete.
                not_ready_to_be_deleted.add(state_group)
                continue

            ready_to_be_deleted.add(state_group)

        # A group can appear in both sets (one row per table), in which case
        # "not ready" wins.
        can_be_deleted = ready_to_be_deleted - not_ready_to_be_deleted
        if not_ready_to_be_deleted:
            # If there are any state groups that aren't ready to be deleted,
            # then we also need to remove any state groups that are referenced
            # by them.
            #
            # NOTE(review): the recursive query below is seeded with every
            # candidate state group, not just `not_ready_to_be_deleted`. That
            # is a superset of what the comment above describes, so it can only
            # exclude more groups; confirm whether this is intentional.
            clause, args = make_in_list_sql_clause(
                self.db_pool.engine,
                column="state_group",
                iterable=state_groups_to_sequence_numbers,
            )
            sql = f"""
                WITH RECURSIVE ancestors(state_group) AS (
                    SELECT DISTINCT prev_state_group
                    FROM state_group_edges WHERE {clause}
                    UNION
                    SELECT prev_state_group
                    FROM state_group_edges
                    INNER JOIN ancestors USING (state_group)
                )
                SELECT state_group FROM ancestors
            """
            txn.execute(sql, args)

            can_be_deleted.difference_update(state_group for (state_group,) in txn)

        return can_be_deleted

    async def get_next_state_group_collection_to_delete(
        self,
    ) -> Optional[Tuple[str, Mapping[int, int]]]:
        """Get the next set of state groups to try and delete

        Returns:
            2-tuple of room_id and mapping of state groups to sequence number,
            or None if no pending deletion is old enough yet.
        """
        return await self.db_pool.runInteraction(
            "get_next_state_group_collection_to_delete",
            self._get_next_state_group_collection_to_delete_txn,
        )

    def _get_next_state_group_collection_to_delete_txn(
        self,
        txn: LoggingTransaction,
    ) -> Optional[Tuple[str, Mapping[int, int]]]:
        """Implementation of `get_next_state_group_collection_to_delete`"""

        # We want to return chunks of state groups that were marked for deletion
        # at the same time (this isn't necessary, just more efficient). We do
        # this by looking for the oldest insertion_ts, and then pulling out all
        # rows that have the same insertion_ts (and room ID).
        now = self._clock.time_msec()

        # State groups that are currently being persisted are excluded via the
        # LEFT JOIN / IS NULL check.
        sql = """
            SELECT room_id, insertion_ts
            FROM state_groups_pending_deletion AS sd
            INNER JOIN state_groups AS sg ON (id = sd.state_group)
            LEFT JOIN state_groups_persisting AS sp USING (state_group)
            WHERE insertion_ts < ? AND sp.state_group IS NULL
            ORDER BY insertion_ts
            LIMIT 1
        """
        txn.execute(sql, (now - self.DELAY_BEFORE_DELETION_MS,))
        row = txn.fetchone()
        if not row:
            return None

        (room_id, insertion_ts) = row

        sql = """
            SELECT state_group, sequence_number
            FROM state_groups_pending_deletion AS sd
            INNER JOIN state_groups AS sg ON (id = sd.state_group)
            LEFT JOIN state_groups_persisting AS sp USING (state_group)
            WHERE room_id = ? AND insertion_ts = ? AND sp.state_group IS NULL
            ORDER BY insertion_ts
        """
        txn.execute(sql, (room_id, insertion_ts))

        return room_id, dict(txn)
index d4ac74c1ee..c1a66dcba0 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py
@@ -22,10 +22,10 @@ import logging from typing import ( TYPE_CHECKING, - Collection, Dict, Iterable, List, + Mapping, Optional, Set, Tuple, @@ -36,7 +36,10 @@ import attr from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.events.snapshot import UnpersistedEventContext, UnpersistedEventContextBase +from synapse.events.snapshot import ( + UnpersistedEventContext, + UnpersistedEventContextBase, +) from synapse.logging.opentracing import tag_args, trace from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -45,6 +48,7 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore +from synapse.storage.engines import PostgresEngine from synapse.storage.types import Cursor from synapse.storage.util.sequence import build_sequence_generator from synapse.types import MutableStateMap, StateKey, StateMap @@ -55,6 +59,7 @@ from synapse.util.cancellation import cancellable if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.databases.state.deletion import StateDeletionDataStore logger = logging.getLogger(__name__) @@ -83,8 +88,10 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", + state_deletion_store: "StateDeletionDataStore", ): super().__init__(database, db_conn, hs) + self._state_deletion_store = state_deletion_store # Originally the state store used a single DictionaryCache to cache the # event IDs for the state types in a given state group to avoid hammering @@ -284,7 +291,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): Returns: Dict of state group to state map. 
""" - state_filter = state_filter or StateFilter.all() + if state_filter is None: + state_filter = StateFilter.all() member_filter, non_member_filter = state_filter.get_member_split() @@ -466,14 +474,15 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): Returns: A list of state groups """ - is_in_db = self.db_pool.simple_select_one_onecol_txn( - txn, - table="state_groups", - keyvalues={"id": prev_group}, - retcol="id", - allow_none=True, + + # We need to check that the prev group isn't about to be deleted + is_missing = ( + self._state_deletion_store._check_state_groups_and_bump_deletion_txn( + txn, + {prev_group}, + ) ) - if not is_in_db: + if is_missing: raise Exception( "Trying to persist state with unpersisted prev_group: %r" % (prev_group,) @@ -545,6 +554,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): for key, state_id in context.state_delta_due_to_event.items() ], ) + return events_and_context return await self.db_pool.runInteraction( @@ -600,14 +610,15 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): The state group if successfully created, or None if the state needs to be persisted as a full state. 
""" - is_in_db = self.db_pool.simple_select_one_onecol_txn( - txn, - table="state_groups", - keyvalues={"id": prev_group}, - retcol="id", - allow_none=True, + + # We need to check that the prev group isn't about to be deleted + is_missing = ( + self._state_deletion_store._check_state_groups_and_bump_deletion_txn( + txn, + {prev_group}, + ) ) - if not is_in_db: + if is_missing: raise Exception( "Trying to persist state with unpersisted prev_group: %r" % (prev_group,) @@ -725,8 +736,10 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ) async def purge_unreferenced_state_groups( - self, room_id: str, state_groups_to_delete: Collection[int] - ) -> None: + self, + room_id: str, + state_groups_to_sequence_numbers: Mapping[int, int], + ) -> bool: """Deletes no longer referenced state groups and de-deltas any state groups that reference them. @@ -734,21 +747,31 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): room_id: The room the state groups belong to (must all be in the same room). state_groups_to_delete: Set of all state groups to delete. + + Returns: + Whether any state groups were actually deleted. 
""" - await self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "purge_unreferenced_state_groups", self._purge_unreferenced_state_groups, room_id, - state_groups_to_delete, + state_groups_to_sequence_numbers, ) def _purge_unreferenced_state_groups( self, txn: LoggingTransaction, room_id: str, - state_groups_to_delete: Collection[int], - ) -> None: + state_groups_to_sequence_numbers: Mapping[int, int], + ) -> bool: + state_groups_to_delete = self._state_deletion_store.get_state_groups_ready_for_potential_deletion_txn( + txn, state_groups_to_sequence_numbers + ) + + if not state_groups_to_delete: + return False + logger.info( "[purge] found %i state groups to delete", len(state_groups_to_delete) ) @@ -767,7 +790,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): remaining_state_groups = { state_group - for state_group, in rows + for (state_group,) in rows if state_group not in state_groups_to_delete } @@ -804,13 +827,23 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): logger.info("[purge] removing redundant state groups") txn.execute_batch( "DELETE FROM state_groups_state WHERE state_group = ?", - ((sg,) for sg in state_groups_to_delete), + [(sg,) for sg in state_groups_to_delete], + ) + txn.execute_batch( + "DELETE FROM state_group_edges WHERE state_group = ?", + [(sg,) for sg in state_groups_to_delete], ) txn.execute_batch( "DELETE FROM state_groups WHERE id = ?", - ((sg,) for sg in state_groups_to_delete), + [(sg,) for sg in state_groups_to_delete], + ) + txn.execute_batch( + "DELETE FROM state_groups_pending_deletion WHERE state_group = ?", + [(sg,) for sg in state_groups_to_delete], ) + return True + @trace @tag_args async def get_previous_state_groups( @@ -829,7 +862,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): List[Tuple[int, int]], await self.db_pool.simple_select_many_batch( table="state_group_edges", - column="prev_state_group", + column="state_group", 
iterable=state_groups, keyvalues={}, retcols=("state_group", "prev_state_group"), @@ -839,60 +872,77 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return dict(rows) - async def purge_room_state( - self, room_id: str, state_groups_to_delete: Collection[int] - ) -> None: - """Deletes all record of a room from state tables + @trace + @tag_args + async def get_next_state_groups( + self, state_groups: Iterable[int] + ) -> Dict[int, int]: + """Fetch the groups that have the given state groups as their previous + state groups. Args: - room_id: - state_groups_to_delete: State groups to delete + state_groups + + Returns: + A mapping from state group to previous state group. """ - logger.info("[purge] Starting state purge") - await self.db_pool.runInteraction( + rows = cast( + List[Tuple[int, int]], + await self.db_pool.simple_select_many_batch( + table="state_group_edges", + column="prev_state_group", + iterable=state_groups, + keyvalues={}, + retcols=("state_group", "prev_state_group"), + desc="get_next_state_groups", + ), + ) + + return dict(rows) + + async def purge_room_state(self, room_id: str) -> None: + return await self.db_pool.runInteraction( "purge_room_state", self._purge_room_state_txn, room_id, - state_groups_to_delete, ) - logger.info("[purge] Done with state purge") def _purge_room_state_txn( self, txn: LoggingTransaction, room_id: str, - state_groups_to_delete: Collection[int], ) -> None: - # first we have to delete the state groups states - logger.info("[purge] removing %s from state_groups_state", room_id) + # Delete all edges that reference a state group linked to room_id + logger.info("[purge] removing %s from state_group_edges", room_id) - self.db_pool.simple_delete_many_txn( - txn, - table="state_groups_state", - column="state_group", - values=state_groups_to_delete, - keyvalues={}, - ) + if isinstance(self.database_engine, PostgresEngine): + # Disable statement timeouts for this transaction; purging rooms can + # take a 
while! + txn.execute("SET LOCAL statement_timeout = 0") - # ... and the state group edges - logger.info("[purge] removing %s from state_group_edges", room_id) + txn.execute( + """ + DELETE FROM state_group_edges AS sge WHERE sge.state_group IN ( + SELECT id FROM state_groups AS sg WHERE sg.room_id = ? + )""", + (room_id,), + ) - self.db_pool.simple_delete_many_txn( - txn, - table="state_group_edges", - column="state_group", - values=state_groups_to_delete, - keyvalues={}, + # state_groups_state table has a room_id column but no index on it, unlike state_groups, + # so we delete them by matching the room_id through the state_groups table. + logger.info("[purge] removing %s from state_groups_state", room_id) + txn.execute( + """ + DELETE FROM state_groups_state AS sgs WHERE sgs.state_group IN ( + SELECT id FROM state_groups AS sg WHERE sg.room_id = ? + )""", + (room_id,), ) - # ... and the state groups logger.info("[purge] removing %s from state_groups", room_id) - - self.db_pool.simple_delete_many_txn( + self.db_pool.simple_delete_txn( txn, table="state_groups", - column="id", - values=state_groups_to_delete, - keyvalues={}, + keyvalues={"room_id": room_id}, ) diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index ad222e7e2d..9d82c59384 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py
@@ -28,6 +28,11 @@ if TYPE_CHECKING: from synapse.storage.database import LoggingDatabaseConnection +# A string that will be replaced with the appropriate auto increment directive +# for the database engine, expands to an auto incrementing integer primary key. +AUTO_INCREMENT_PRIMARY_KEYPLACEHOLDER = "$%AUTO_INCREMENT_PRIMARY_KEY%$" + + class IsolationLevel(IntEnum): READ_COMMITTED: int = 1 REPEATABLE_READ: int = 2 diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 90641d5a18..e4cd359201 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py
@@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Any, Mapping, NoReturn, Optional, Tuple, cast import psycopg2.extensions from synapse.storage.engines._base import ( + AUTO_INCREMENT_PRIMARY_KEYPLACEHOLDER, BaseDatabaseEngine, IncorrectDatabaseSetup, IsolationLevel, @@ -98,8 +99,8 @@ class PostgresEngine( allow_unsafe_locale = self.config.get("allow_unsafe_locale", False) # Are we on a supported PostgreSQL version? - if not allow_outdated_version and self._version < 110000: - raise RuntimeError("Synapse requires PostgreSQL 11 or above.") + if not allow_outdated_version and self._version < 130000: + raise RuntimeError("Synapse requires PostgreSQL 13 or above.") with db_conn.cursor() as txn: txn.execute("SHOW SERVER_ENCODING") @@ -256,4 +257,10 @@ class PostgresEngine( executing the script in its own transaction. The script transaction is left open and it is the responsibility of the caller to commit it. """ + # Replace auto increment placeholder with the appropriate directive + script = script.replace( + AUTO_INCREMENT_PRIMARY_KEYPLACEHOLDER, + "BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY", + ) + cursor.execute(f"COMMIT; BEGIN TRANSACTION; {script}") diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index b11094c5c1..9d1795ebe5 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py
@@ -25,6 +25,7 @@ import threading from typing import TYPE_CHECKING, Any, List, Mapping, Optional from synapse.storage.engines import BaseDatabaseEngine +from synapse.storage.engines._base import AUTO_INCREMENT_PRIMARY_KEYPLACEHOLDER from synapse.storage.types import Cursor if TYPE_CHECKING: @@ -168,6 +169,11 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]): > first. No other implicit transaction control is performed; any transaction > control must be added to sql_script. """ + # Replace auto increment placeholder with the appropriate directive + script = script.replace( + AUTO_INCREMENT_PRIMARY_KEYPLACEHOLDER, "INTEGER PRIMARY KEY AUTOINCREMENT" + ) + # The implementation of `executescript` can be found at # https://github.com/python/cpython/blob/3.11/Modules/_sqlite/cursor.c#L1035. cursor.executescript(f"BEGIN TRANSACTION; {script}") diff --git a/synapse/storage/invite_rule.py b/synapse/storage/invite_rule.py new file mode 100644
index 0000000000..b9d9d1eb62 --- /dev/null +++ b/synapse/storage/invite_rule.py
@@ -0,0 +1,123 @@
+import logging
+from enum import Enum
+from typing import Optional, Pattern
+
+from matrix_common.regex import glob_to_regex
+
+from synapse.types import JsonMapping, UserID
+
+logger = logging.getLogger(__name__)
+
+
+class InviteRule(Enum):
+    """Enum to define the action taken when an invite matches a rule."""
+
+    ALLOW = "allow"
+    BLOCK = "block"
+    IGNORE = "ignore"
+
+
+class InviteRulesConfig:
+    """Class to determine if a given user permits an invite from another user, and the action to take."""
+
+    def __init__(self, account_data: Optional[JsonMapping]):
+        """Compile the glob patterns found in `account_data` into regexes.
+
+        Args:
+            account_data: Mapping that may carry the `allowed_users`,
+                `ignored_users`, `blocked_users`, `allowed_servers`,
+                `ignored_servers` and `blocked_servers` lists of globs. When it
+                is None or empty, every rule list stays empty.
+        """
+        self.allowed_users: list[Pattern[str]] = []
+        self.ignored_users: list[Pattern[str]] = []
+        self.blocked_users: list[Pattern[str]] = []
+
+        self.allowed_servers: list[Pattern[str]] = []
+        self.ignored_servers: list[Pattern[str]] = []
+        self.blocked_servers: list[Pattern[str]] = []
+
+        def process_field(
+            values: Optional[list[str]],
+            ruleset: list[Pattern[str]],
+            rule: InviteRule,
+        ) -> None:
+            """Compile each usable glob in `values` and append it to `ruleset`.
+
+            Non-list `values` and non-string or empty entries are skipped. The
+            `rule` argument is accepted but not read here.
+            """
+            if isinstance(values, list):
+                for value in values:
+                    if isinstance(value, str) and len(value) > 0:
+                        # User IDs cannot exceed 255 bytes. Don't process large, potentially
+                        # expensive glob patterns.
+                        if len(value) > 255:
+                            logger.debug(
+                                f"Ignoring invite config glob pattern that is >255 bytes: {value}"
+                            )
+                            continue
+
+                        try:
+                            ruleset.append(glob_to_regex(value))
+                        except Exception as e:
+                            # If for whatever reason we can't process this, just ignore it.
+                            logger.debug(
+                                f"Could not process '{value}' field of invite rule config, ignoring: {e}"
+                            )
+
+        if account_data:
+            process_field(
+                account_data.get("allowed_users"), self.allowed_users, InviteRule.ALLOW
+            )
+            process_field(
+                account_data.get("ignored_users"), self.ignored_users, InviteRule.IGNORE
+            )
+            process_field(
+                account_data.get("blocked_users"), self.blocked_users, InviteRule.BLOCK
+            )
+            process_field(
+                account_data.get("allowed_servers"),
+                self.allowed_servers,
+                InviteRule.ALLOW,
+            )
+            process_field(
+                account_data.get("ignored_servers"),
+                self.ignored_servers,
+                InviteRule.IGNORE,
+            )
+            process_field(
+                account_data.get("blocked_servers"),
+                self.blocked_servers,
+                InviteRule.BLOCK,
+            )
+
+    def get_invite_rule(self, user_id: str) -> InviteRule:
+        """Get the invite rule that matches this user. Will return InviteRule.ALLOW if no rules match
+
+        Args:
+            user_id: The user ID of the inviting user.
+
+        """
+        user = UserID.from_string(user_id)
+        # The order here is important. We always process user rules before server rules
+        # and we always process in the order of Allow, Ignore, Block.
+        for patterns, rule in [
+            (self.allowed_users, InviteRule.ALLOW),
+            (self.ignored_users, InviteRule.IGNORE),
+            (self.blocked_users, InviteRule.BLOCK),
+        ]:
+            for regex in patterns:
+                if regex.match(user_id):
+                    return rule
+
+        for patterns, rule in [
+            (self.allowed_servers, InviteRule.ALLOW),
+            (self.ignored_servers, InviteRule.IGNORE),
+            (self.blocked_servers, InviteRule.BLOCK),
+        ]:
+            for regex in patterns:
+                if regex.match(user.domain):
+                    return rule
+
+        return InviteRule.ALLOW
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index aaffe5ecc9..bf087702ea 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py
@@ -607,7 +607,7 @@ def _apply_module_schema_files( "SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,), ) - applied_deltas = {d for d, in cur} + applied_deltas = {d for (d,) in cur} for name, stream in names_and_streams: if name in applied_deltas: continue @@ -710,7 +710,7 @@ def _get_or_create_schema_state( "SELECT file FROM applied_schema_deltas WHERE version >= ?", (current_version,), ) - applied_deltas = tuple(d for d, in txn) + applied_deltas = tuple(d for (d,) in txn) return _SchemaState( current_version=current_version, diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 80c9630867..9dc6c395e8 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py
@@ -40,6 +40,34 @@ class RoomsForUser: @attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True) +class RoomsForUserSlidingSync: + room_id: str + sender: Optional[str] + membership: str + event_id: Optional[str] + event_pos: PersistedEventPosition + room_version_id: str + + has_known_state: bool + room_type: Optional[str] + is_encrypted: bool + + +@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True) +class RoomsForUserStateReset: + """A version of `RoomsForUser` that supports optional sender and event ID + fields, to handle state resets. State resets can affect room membership + without a corresponding event so that information isn't always available.""" + + room_id: str + sender: Optional[str] + membership: str + event_id: Optional[str] + event_pos: PersistedEventPosition + room_version_id: str + + +@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True) class GetRoomsForUserWithStreamOrdering: room_id: str event_pos: PersistedEventPosition diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 581d00346b..3c3b13437e 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py
@@ -2,7 +2,7 @@ # This file is licensed under the Affero General Public License (AGPL) version 3. # # Copyright 2021 The Matrix.org Foundation C.I.C. -# Copyright (C) 2023 New Vector, Ltd +# Copyright (C) 2023-2024 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as @@ -19,7 +19,7 @@ # # -SCHEMA_VERSION = 86 # remember to update the list below when updating +SCHEMA_VERSION = 92 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -142,6 +142,32 @@ Changes in SCHEMA_VERSION = 85 Changes in SCHEMA_VERSION = 86 - Add a column `authenticated` to the tables `local_media_repository` and `remote_media_cache` + +Changes in SCHEMA_VERSION = 87 + - Add tables to store Sliding Sync data for quick filtering/sorting + (`sliding_sync_joined_rooms`, `sliding_sync_membership_snapshots`) + - Add tables for storing the per-connection state for sliding sync requests: + sliding_sync_connections, sliding_sync_connection_positions, sliding_sync_connection_required_state, + sliding_sync_connection_room_configs, sliding_sync_connection_streams + +Changes in SCHEMA_VERSION = 88 + - MSC4140: Add `delayed_events` table that keeps track of events that are to + be posted in response to a resettable timeout or an on-demand action. + - Add background update to fix data integrity issue in the + `sliding_sync_membership_snapshots` -> `forgotten` column + +Changes in SCHEMA_VERSION = 89 + - Add `state_groups_pending_deletion` and `state_groups_persisting` tables. + +Changes in SCHEMA_VERSION = 90 + - Add a column `participant` to `room_memberships` table + - Add background update to delete unreferenced state groups. + +Changes in SCHEMA_VERSION = 91 + - Add a `sha256` column to the `local_media_repository` and `remote_media_cache` tables. 
+ +Changes in SCHEMA_VERSION = 92 + - Cleaned up a trigger that was added in #18260 and then reverted. """ diff --git a/synapse/storage/schema/main/delta/25/fts.py b/synapse/storage/schema/main/delta/25/fts.py
index b050cc16a7..c01c1325cb 100644 --- a/synapse/storage/schema/main/delta/25/fts.py +++ b/synapse/storage/schema/main/delta/25/fts.py
@@ -75,8 +75,7 @@ def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> progress_json = json.dumps(progress) sql = ( - "INSERT into background_updates (update_name, progress_json)" - " VALUES (?, ?)" + "INSERT into background_updates (update_name, progress_json) VALUES (?, ?)" ) cur.execute(sql, ("event_search", progress_json)) diff --git a/synapse/storage/schema/main/delta/27/ts.py b/synapse/storage/schema/main/delta/27/ts.py
index d7f360b6e6..e6e73e1b77 100644 --- a/synapse/storage/schema/main/delta/27/ts.py +++ b/synapse/storage/schema/main/delta/27/ts.py
@@ -55,8 +55,7 @@ def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> progress_json = json.dumps(progress) sql = ( - "INSERT into background_updates (update_name, progress_json)" - " VALUES (?, ?)" + "INSERT into background_updates (update_name, progress_json) VALUES (?, ?)" ) cur.execute(sql, ("event_origin_server_ts", progress_json)) diff --git a/synapse/storage/schema/main/delta/31/search_update.py b/synapse/storage/schema/main/delta/31/search_update.py
index 0e65c9a841..46355122bb 100644 --- a/synapse/storage/schema/main/delta/31/search_update.py +++ b/synapse/storage/schema/main/delta/31/search_update.py
@@ -59,8 +59,7 @@ def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> progress_json = json.dumps(progress) sql = ( - "INSERT into background_updates (update_name, progress_json)" - " VALUES (?, ?)" + "INSERT into background_updates (update_name, progress_json) VALUES (?, ?)" ) cur.execute(sql, ("event_search_order", progress_json)) diff --git a/synapse/storage/schema/main/delta/33/event_fields.py b/synapse/storage/schema/main/delta/33/event_fields.py
index 9c02aeda88..53d215337e 100644 --- a/synapse/storage/schema/main/delta/33/event_fields.py +++ b/synapse/storage/schema/main/delta/33/event_fields.py
@@ -55,8 +55,7 @@ def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> progress_json = json.dumps(progress) sql = ( - "INSERT into background_updates (update_name, progress_json)" - " VALUES (?, ?)" + "INSERT into background_updates (update_name, progress_json) VALUES (?, ?)" ) cur.execute(sql, ("event_fields_sender_url", progress_json)) diff --git a/synapse/storage/schema/main/delta/56/unique_user_filter_index.py b/synapse/storage/schema/main/delta/56/unique_user_filter_index.py
index 2461f87d77..b7535dae14 100644 --- a/synapse/storage/schema/main/delta/56/unique_user_filter_index.py +++ b/synapse/storage/schema/main/delta/56/unique_user_filter_index.py
@@ -41,8 +41,6 @@ def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> (user_id, filter_id); DROP TABLE user_filters; ALTER TABLE user_filters_migration RENAME TO user_filters; - """ % ( - select_clause, - ) + """ % (select_clause,) execute_statements_from_stream(cur, StringIO(sql)) diff --git a/synapse/storage/schema/main/delta/61/03recreate_min_depth.py b/synapse/storage/schema/main/delta/61/03recreate_min_depth.py
index 5d3578eaf4..a847ef4147 100644 --- a/synapse/storage/schema/main/delta/61/03recreate_min_depth.py +++ b/synapse/storage/schema/main/delta/61/03recreate_min_depth.py
@@ -23,6 +23,7 @@ This migration handles the process of changing the type of `room_depth.min_depth` to a BIGINT. """ + from synapse.storage.database import LoggingTransaction from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine diff --git a/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py b/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py
index b4d4b6536b..9ac3d1d31f 100644 --- a/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py +++ b/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py
@@ -25,6 +25,7 @@ This migration adds triggers to the partial_state_events tables to enforce uniqu Triggers cannot be expressed in .sql files, so we have to use a separate file. """ + from synapse.storage.database import LoggingTransaction from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine diff --git a/synapse/storage/schema/main/delta/72/07force_update_current_state_events_membership.py b/synapse/storage/schema/main/delta/72/07force_update_current_state_events_membership.py
index 93543fca7c..be80a6747d 100644 --- a/synapse/storage/schema/main/delta/72/07force_update_current_state_events_membership.py +++ b/synapse/storage/schema/main/delta/72/07force_update_current_state_events_membership.py
@@ -26,6 +26,7 @@ for its completion can be removed. Note the background job must still remain defined in the database class. """ + from synapse.config.homeserver import HomeServerConfig from synapse.storage.database import LoggingTransaction from synapse.storage.engines import BaseDatabaseEngine diff --git a/synapse/storage/schema/main/delta/74/04_membership_tables_event_stream_ordering_triggers.py b/synapse/storage/schema/main/delta/74/04_membership_tables_event_stream_ordering_triggers.py
index 6609ef0dac..a847a93494 100644 --- a/synapse/storage/schema/main/delta/74/04_membership_tables_event_stream_ordering_triggers.py +++ b/synapse/storage/schema/main/delta/74/04_membership_tables_event_stream_ordering_triggers.py
@@ -24,6 +24,7 @@ This migration adds triggers to the room membership tables to enforce consistency. Triggers cannot be expressed in .sql files, so we have to use a separate file. """ + from synapse.storage.database import LoggingTransaction from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine diff --git a/synapse/storage/schema/main/delta/78/03event_extremities_constraints.py b/synapse/storage/schema/main/delta/78/03event_extremities_constraints.py
index ad9c394162..1c823a3aa1 100644 --- a/synapse/storage/schema/main/delta/78/03event_extremities_constraints.py +++ b/synapse/storage/schema/main/delta/78/03event_extremities_constraints.py
@@ -23,6 +23,7 @@ """ This migration adds foreign key constraint to `event_forward_extremities` table. """ + from synapse.storage.background_updates import ( ForeignKeyConstraint, run_validate_constraint_and_delete_rows_schema_delta, diff --git a/synapse/storage/schema/main/delta/86/02_receipts_event_id_index.sql b/synapse/storage/schema/main/delta/86/02_receipts_event_id_index.sql new file mode 100644
index 0000000000..e6db91e5b5 --- /dev/null +++ b/synapse/storage/schema/main/delta/86/02_receipts_event_id_index.sql
@@ -0,0 +1,15 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2024 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- <https://www.gnu.org/licenses/agpl-3.0.html>. + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (8602, 'receipts_room_id_event_id_index', '{}'); diff --git a/synapse/storage/schema/main/delta/87/01_sliding_sync_memberships.sql b/synapse/storage/schema/main/delta/87/01_sliding_sync_memberships.sql new file mode 100644
index 0000000000..2f71e541f8 --- /dev/null +++ b/synapse/storage/schema/main/delta/87/01_sliding_sync_memberships.sql
@@ -0,0 +1,169 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2024 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- <https://www.gnu.org/licenses/agpl-3.0.html>. + +-- This table is a list/queue used to keep track of which rooms need to be inserted into +-- `sliding_sync_joined_rooms`. We do this to avoid reading from `current_state_events` +-- during the background update to populate `sliding_sync_joined_rooms` which works but +-- it takes a lot of work for the database to grab `DISTINCT` room_ids given how many +-- state events there are for each room. +-- +-- This table is prefilled with every room in the `rooms` table (see the +-- `sliding_sync_prefill_joined_rooms_to_recalculate_table_bg_update` background +-- update). This table is also updated whenever we come across stale data so that we can +-- catch-up with all of the new data if Synapse was downgraded (see +-- `_resolve_stale_data_in_sliding_sync_tables`). +-- +-- FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the +-- foreground update for +-- `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by +-- https://github.com/element-hq/synapse/issues/17623) +CREATE TABLE IF NOT EXISTS sliding_sync_joined_rooms_to_recalculate( + room_id TEXT NOT NULL REFERENCES rooms(room_id), + PRIMARY KEY (room_id) +); + +-- A table for storing room meta data (current state relevant to sliding sync) that the +-- local server is still participating in (someone local is joined to the room). 
+-- +-- We store the joined rooms in a separate table from `sliding_sync_membership_snapshots` +-- because we need up-to-date information for joined rooms and it can be shared across +-- everyone who is joined. +-- +-- This table is kept in sync with `current_state_events` which means if the server is +-- no longer participating in a room, the row will be deleted. +CREATE TABLE IF NOT EXISTS sliding_sync_joined_rooms( + room_id TEXT NOT NULL REFERENCES rooms(room_id), + -- The `stream_ordering` of the most-recent/latest event in the room + event_stream_ordering BIGINT NOT NULL REFERENCES events(stream_ordering), + -- The `stream_ordering` of the last event according to the `bump_event_types` + bump_stamp BIGINT, + -- `m.room.create` -> `content.type` (current state) + -- + -- Useful for the `spaces`/`not_spaces` filter in the Sliding Sync API + room_type TEXT, + -- `m.room.name` -> `content.name` (current state) + -- + -- Useful for the room meta data and `room_name_like` filter in the Sliding Sync API + room_name TEXT, + -- `m.room.encryption` -> `content.algorithm` (current state) + -- + -- Useful for the `is_encrypted` filter in the Sliding Sync API + is_encrypted BOOLEAN DEFAULT FALSE NOT NULL, + -- `m.room.tombstone` -> `content.replacement_room` (according to the current state at the + -- time of the membership). + -- + -- Useful for the `include_old_rooms` functionality in the Sliding Sync API + tombstone_successor_room_id TEXT, + PRIMARY KEY (room_id) +); + +-- So we can purge rooms easily. +-- +-- The primary key is already `room_id` + +-- So we can sort by `stream_ordering` +CREATE UNIQUE INDEX IF NOT EXISTS sliding_sync_joined_rooms_event_stream_ordering ON sliding_sync_joined_rooms(event_stream_ordering); + +-- A table for storing a snapshot of room meta data (historical current state relevant +-- for sliding sync) at the time of a local user's membership. 
Only has rows for the +-- latest membership event for a given local user in a room which matches +-- `local_current_membership` . +-- +-- We store all memberships including joins. This makes it easy to reference this table +-- to find all membership for a given user and shares the same semantics as +-- `local_current_membership`. And we get to avoid some table maintenance; if we only +-- stored non-joins, we would have to delete the row for the user when the user joins +-- the room. Stripped state doesn't include the `m.room.tombstone` event, so we just +-- assume that the room doesn't have a tombstone. +-- +-- For remote invite/knocks where the server is not participating in the room, we will +-- use stripped state events to populate this table. We assume that if any stripped +-- state is given, it will include all possible stripped state events types. For +-- example, if stripped state is given but `m.room.encryption` isn't included, we will +-- assume that the room is not encrypted. +-- +-- We don't include `bump_stamp` here because we can just use the `stream_ordering` from +-- the membership event itself as the `bump_stamp`. +CREATE TABLE IF NOT EXISTS sliding_sync_membership_snapshots( + room_id TEXT NOT NULL REFERENCES rooms(room_id), + user_id TEXT NOT NULL, + -- Useful to be able to tell leaves from kicks (where the `user_id` is different from the `sender`) + sender TEXT NOT NULL, + membership_event_id TEXT NOT NULL REFERENCES events(event_id), + membership TEXT NOT NULL, + -- This is an integer just to match `room_memberships` and also means we don't need + -- to do any casting. + forgotten INTEGER DEFAULT 0 NOT NULL, + -- `stream_ordering` of the `membership_event_id` + event_stream_ordering BIGINT NOT NULL REFERENCES events(stream_ordering), + -- `instance_name` of the worker that persisted the `membership_event_id`. 
+ -- Useful for crafting `PersistedEventPosition(...)` + event_instance_name TEXT NOT NULL, + -- For remote invites/knocks that don't include any stripped state, we want to be + -- able to distinguish between a room with `None` as valid value for some state and + -- room where the state is completely unknown. Basically, this should be True unless + -- no stripped state was provided for a remote invite/knock (False). + has_known_state BOOLEAN DEFAULT FALSE NOT NULL, + -- `m.room.create` -> `content.type` (according to the current state at the time of + -- the membership). + -- + -- Useful for the `spaces`/`not_spaces` filter in the Sliding Sync API + room_type TEXT, + -- `m.room.name` -> `content.name` (according to the current state at the time of + -- the membership). + -- + -- Useful for the room meta data and `room_name_like` filter in the Sliding Sync API + room_name TEXT, + -- `m.room.encryption` -> `content.algorithm` (according to the current state at the + -- time of the membership). + -- + -- Useful for the `is_encrypted` filter in the Sliding Sync API + is_encrypted BOOLEAN DEFAULT FALSE NOT NULL, + -- `m.room.tombstone` -> `content.replacement_room` (according to the current state at the + -- time of the membership). + -- + -- Useful for the `include_old_rooms` functionality in the Sliding Sync API + tombstone_successor_room_id TEXT, + PRIMARY KEY (room_id, user_id) +); + +-- So we can purge rooms easily. +-- +-- Since we're using a multi-column index as the primary key (room_id, user_id), the +-- first index column (room_id) is always usable for searching so we don't need to +-- create a separate index for it. 
+-- +-- CREATE INDEX IF NOT EXISTS sliding_sync_membership_snapshots_room_id ON sliding_sync_membership_snapshots(room_id); + +-- So we can fetch all rooms for a given user +CREATE INDEX IF NOT EXISTS sliding_sync_membership_snapshots_user_id ON sliding_sync_membership_snapshots(user_id); +-- So we can sort by `stream_ordering` +CREATE UNIQUE INDEX IF NOT EXISTS sliding_sync_membership_snapshots_event_stream_ordering ON sliding_sync_membership_snapshots(event_stream_ordering); + + +-- Add a series of background updates to populate the new `sliding_sync_joined_rooms` table: +-- +-- 1. Add a background update to prefill `sliding_sync_joined_rooms_to_recalculate`. +-- We do a one-shot bulk insert from the `rooms` table to prefill. +-- 2. Add a background update to populate the new `sliding_sync_joined_rooms` table +-- based on the rooms listed in the `sliding_sync_joined_rooms_to_recalculate` +-- table. +-- +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (8701, 'sliding_sync_prefill_joined_rooms_to_recalculate_table_bg_update', '{}'); +INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + (8701, 'sliding_sync_joined_rooms_bg_update', '{}', 'sliding_sync_prefill_joined_rooms_to_recalculate_table_bg_update'); + +-- Add a background update to populate the new `sliding_sync_membership_snapshots` table +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (8701, 'sliding_sync_membership_snapshots_bg_update', '{}'); diff --git a/synapse/storage/schema/main/delta/87/02_per_connection_state.sql b/synapse/storage/schema/main/delta/87/02_per_connection_state.sql new file mode 100644
index 0000000000..59bc14a2c9 --- /dev/null +++ b/synapse/storage/schema/main/delta/87/02_per_connection_state.sql
@@ -0,0 +1,81 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2024 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- <https://www.gnu.org/licenses/agpl-3.0.html>. + + +-- Table to track active sliding sync connections. +-- +-- A new connection will be created for every sliding sync request without a +-- `since` token for a given `conn_id` for a device. +-- +-- Once a new connection is created and used we delete all other connections for +-- the `conn_id`. +CREATE TABLE sliding_sync_connections( + connection_key $%AUTO_INCREMENT_PRIMARY_KEY%$, + user_id TEXT NOT NULL, + -- Generally the device ID, but may be something else for e.g. puppeted accounts. + effective_device_id TEXT NOT NULL, + conn_id TEXT NOT NULL, + created_ts BIGINT NOT NULL +); + +CREATE INDEX sliding_sync_connections_idx ON sliding_sync_connections(user_id, effective_device_id, conn_id); +CREATE INDEX sliding_sync_connections_ts_idx ON sliding_sync_connections(created_ts); + +-- We track per-connection state by associating changes to the state with +-- connection positions. This ensures that we correctly track state even if we +-- see retries of requests. +-- +-- If the client starts a "new" connection (by not specifying a since token), +-- we'll clear out the other connections (to ensure that we don't end up with +-- lots of connection keys). 
+CREATE TABLE sliding_sync_connection_positions( + connection_position $%AUTO_INCREMENT_PRIMARY_KEY%$, + connection_key BIGINT NOT NULL REFERENCES sliding_sync_connections(connection_key) ON DELETE CASCADE, + created_ts BIGINT NOT NULL +); + +CREATE INDEX sliding_sync_connection_positions_key ON sliding_sync_connection_positions(connection_key); +CREATE INDEX sliding_sync_connection_positions_ts_idx ON sliding_sync_connection_positions(created_ts); + + +-- To save space we deduplicate the `required_state` json by assigning IDs to +-- different values. +CREATE TABLE sliding_sync_connection_required_state( + required_state_id $%AUTO_INCREMENT_PRIMARY_KEY%$, + connection_key BIGINT NOT NULL REFERENCES sliding_sync_connections(connection_key) ON DELETE CASCADE, + required_state TEXT NOT NULL -- We store this as a json list of event type / state key tuples. +); + +CREATE INDEX sliding_sync_connection_required_state_conn_pos ON sliding_sync_connection_required_state(connection_key); + + +-- Stores the room configs we have seen for rooms in a connection. +CREATE TABLE sliding_sync_connection_room_configs( + connection_position BIGINT NOT NULL REFERENCES sliding_sync_connection_positions(connection_position) ON DELETE CASCADE, + room_id TEXT NOT NULL, + timeline_limit BIGINT NOT NULL, + required_state_id BIGINT NOT NULL REFERENCES sliding_sync_connection_required_state(required_state_id) +); + +CREATE UNIQUE INDEX sliding_sync_connection_room_configs_idx ON sliding_sync_connection_room_configs(connection_position, room_id); + +-- Stores what data we have sent for given streams down given connections. +CREATE TABLE sliding_sync_connection_streams( + connection_position BIGINT NOT NULL REFERENCES sliding_sync_connection_positions(connection_position) ON DELETE CASCADE, + stream TEXT NOT NULL, -- e.g. "events" or "receipts" + room_id TEXT NOT NULL, + room_status TEXT NOT NULL, -- "live" or "previously", i.e. 
the `HaveSentRoomFlag` value + last_token TEXT -- For "previously" the token for the stream we have sent up to. +); + +CREATE UNIQUE INDEX sliding_sync_connection_streams_idx ON sliding_sync_connection_streams(connection_position, room_id, stream); diff --git a/synapse/storage/schema/main/delta/87/03_current_state_index.sql b/synapse/storage/schema/main/delta/87/03_current_state_index.sql new file mode 100644
index 0000000000..76b974271c --- /dev/null +++ b/synapse/storage/schema/main/delta/87/03_current_state_index.sql
@@ -0,0 +1,19 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2024 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- <https://www.gnu.org/licenses/agpl-3.0.html>. + + +-- Add a background update to add a new index: +-- `current_state_events(room_id, membership) WHERE type = 'm.room.member'` +-- This makes counting membership in rooms (for syncs) much faster +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (8701, 'current_state_events_members_room_index', '{}'); diff --git a/synapse/storage/schema/main/delta/88/01_add_delayed_events.sql b/synapse/storage/schema/main/delta/88/01_add_delayed_events.sql new file mode 100644
index 0000000000..78ba5129af --- /dev/null +++ b/synapse/storage/schema/main/delta/88/01_add_delayed_events.sql
@@ -0,0 +1,43 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2024 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- <https://www.gnu.org/licenses/agpl-3.0.html>. + +CREATE TABLE delayed_events ( + delay_id TEXT NOT NULL, + user_localpart TEXT NOT NULL, + device_id TEXT, + delay BIGINT NOT NULL, + send_ts BIGINT NOT NULL, + room_id TEXT NOT NULL, + event_type TEXT NOT NULL, + state_key TEXT, + origin_server_ts BIGINT, + content bytea NOT NULL, + is_processed BOOLEAN NOT NULL DEFAULT FALSE, + PRIMARY KEY (user_localpart, delay_id) +); + +CREATE INDEX delayed_events_send_ts ON delayed_events (send_ts); +CREATE INDEX delayed_events_is_processed ON delayed_events (is_processed); +CREATE INDEX delayed_events_room_state_event_idx ON delayed_events (room_id, event_type, state_key) WHERE state_key IS NOT NULL; + +CREATE TABLE delayed_events_stream_pos ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_id BIGINT NOT NULL, + CHECK (Lock='X') +); + +-- Start processing events from the point this migration was run, rather +-- than the beginning of time. +INSERT INTO delayed_events_stream_pos ( + stream_id +) SELECT COALESCE(MAX(stream_ordering), 0) from events; diff --git a/synapse/storage/schema/main/delta/88/01_custom_profile_fields.sql b/synapse/storage/schema/main/delta/88/01_custom_profile_fields.sql new file mode 100644
index 0000000000..63cbd7ffa9 --- /dev/null +++ b/synapse/storage/schema/main/delta/88/01_custom_profile_fields.sql
@@ -0,0 +1,15 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2024 Patrick Cloke +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- <https://www.gnu.org/licenses/agpl-3.0.html>. + +-- Custom profile fields. +ALTER TABLE profiles ADD COLUMN fields JSONB; diff --git a/synapse/storage/schema/main/delta/88/02_fix_sliding_sync_membership_snapshots_forgotten_column.sql b/synapse/storage/schema/main/delta/88/02_fix_sliding_sync_membership_snapshots_forgotten_column.sql new file mode 100644
index 0000000000..4de46af2fc --- /dev/null +++ b/synapse/storage/schema/main/delta/88/02_fix_sliding_sync_membership_snapshots_forgotten_column.sql
@@ -0,0 +1,21 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2024 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- <https://www.gnu.org/licenses/agpl-3.0.html>. + +-- Add a background update to update the `sliding_sync_membership_snapshots` -> +-- `forgotten` column to be in sync with the `room_memberships` table. +-- +-- For any room that someone has forgotten and subsequently re-joined or had any new +-- membership on, we need to go and update the column to match the `room_memberships` +-- table as it has fallen out of sync. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (8802, 'sliding_sync_membership_snapshots_fix_forgotten_column_bg_update', '{}'); diff --git a/synapse/storage/schema/main/delta/88/03_add_otk_ts_added_index.sql b/synapse/storage/schema/main/delta/88/03_add_otk_ts_added_index.sql new file mode 100644
index 0000000000..7712ea68ad --- /dev/null +++ b/synapse/storage/schema/main/delta/88/03_add_otk_ts_added_index.sql
@@ -0,0 +1,18 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2024 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- <https://www.gnu.org/licenses/agpl-3.0.html>. + + +-- Add an index on (user_id, device_id, algorithm, ts_added_ms) on e2e_one_time_keys_json, so that OTKs can +-- efficiently be issued in the same order they were uploaded. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (8803, 'add_otk_ts_added_index', '{}'); diff --git a/synapse/storage/schema/main/delta/88/04_current_state_delta_index.sql b/synapse/storage/schema/main/delta/88/04_current_state_delta_index.sql new file mode 100644
index 0000000000..0ee78df1a0 --- /dev/null +++ b/synapse/storage/schema/main/delta/88/04_current_state_delta_index.sql
@@ -0,0 +1,18 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2024 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- <https://www.gnu.org/licenses/agpl-3.0.html>. + + +-- Add an index on `current_state_delta_stream(room_id, stream_id)` to allow +-- efficient per-room lookups. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (8804, 'current_state_delta_stream_room_index', '{}'); diff --git a/synapse/storage/schema/main/delta/88/05_drop_old_otks.sql.postgres b/synapse/storage/schema/main/delta/88/05_drop_old_otks.sql.postgres new file mode 100644
index 0000000000..93a68836ee --- /dev/null +++ b/synapse/storage/schema/main/delta/88/05_drop_old_otks.sql.postgres
@@ -0,0 +1,19 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2024 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- <https://www.gnu.org/licenses/agpl-3.0.html>. + +-- Until Synapse 1.119, Synapse used to issue one-time-keys in a random order, leading to the possibility +-- that it could still have old OTKs that the client has dropped. +-- +-- We create a scheduled task which will drop old OTKs, to flush them out. +INSERT INTO scheduled_tasks(id, action, status, timestamp) + VALUES ('delete_old_otks_task', 'delete_old_otks', 'scheduled', extract(epoch from current_timestamp) * 1000); diff --git a/synapse/storage/schema/main/delta/88/05_drop_old_otks.sql.sqlite b/synapse/storage/schema/main/delta/88/05_drop_old_otks.sql.sqlite new file mode 100644
index 0000000000..cdc2b5d211 --- /dev/null +++ b/synapse/storage/schema/main/delta/88/05_drop_old_otks.sql.sqlite
@@ -0,0 +1,19 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2024 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- <https://www.gnu.org/licenses/agpl-3.0.html>. + +-- Until Synapse 1.119, Synapse used to issue one-time-keys in a random order, leading to the possibility +-- that it could still have old OTKs that the client has dropped. +-- +-- We create a scheduled task which will drop old OTKs, to flush them out. +INSERT INTO scheduled_tasks(id, action, status, timestamp) + VALUES ('delete_old_otks_task', 'delete_old_otks', 'scheduled', strftime('%s', 'now') * 1000); diff --git a/synapse/storage/schema/main/delta/88/05_sliding_sync_room_config_index.sql b/synapse/storage/schema/main/delta/88/05_sliding_sync_room_config_index.sql new file mode 100644
index 0000000000..7b2e18a84b --- /dev/null +++ b/synapse/storage/schema/main/delta/88/05_sliding_sync_room_config_index.sql
@@ -0,0 +1,20 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2024 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- <https://www.gnu.org/licenses/agpl-3.0.html>. + + +-- Add an index on sliding_sync_connection_room_configs(required_state_id), so +-- that when we delete entries in `sliding_sync_connection_required_state` it's +-- efficient for Postgres to check they've been deleted from +-- `sliding_sync_connection_room_configs` too +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (8805, 'sliding_sync_connection_room_configs_required_state_id_idx', '{}'); diff --git a/synapse/storage/schema/main/delta/88/06_events_received_ts_index.sql b/synapse/storage/schema/main/delta/88/06_events_received_ts_index.sql new file mode 100644
index 0000000000..d70a4a8dbc --- /dev/null +++ b/synapse/storage/schema/main/delta/88/06_events_received_ts_index.sql
@@ -0,0 +1,17 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2024 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- <https://www.gnu.org/licenses/agpl-3.0.html>. + +-- Add an index on `events.received_ts` for `m.room.member` events to allow for +-- efficient lookup of events by timestamp in some Admin APIs +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (8806, 'events_received_ts_index', '{}'); diff --git a/synapse/storage/schema/main/delta/89/01_sliding_sync_membership_snapshot_index.sql b/synapse/storage/schema/main/delta/89/01_sliding_sync_membership_snapshot_index.sql new file mode 100644
index 0000000000..7799cffdce --- /dev/null +++ b/synapse/storage/schema/main/delta/89/01_sliding_sync_membership_snapshot_index.sql
@@ -0,0 +1,15 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- <https://www.gnu.org/licenses/agpl-3.0.html>. + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (8901, 'sliding_sync_membership_snapshots_membership_event_id_idx', '{}'); diff --git a/synapse/storage/schema/main/delta/90/01_add_column_participant_room_memberships_table.sql b/synapse/storage/schema/main/delta/90/01_add_column_participant_room_memberships_table.sql new file mode 100644
index 0000000000..dafd046499 --- /dev/null +++ b/synapse/storage/schema/main/delta/90/01_add_column_participant_room_memberships_table.sql
@@ -0,0 +1,16 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- <https://www.gnu.org/licenses/agpl-3.0.html>. + +-- Add a column `participant` to `room_memberships` table to track whether a room member has sent +-- a `m.room.message` or `m.room.encrypted` event into a room they are a member of +ALTER TABLE room_memberships ADD COLUMN participant BOOLEAN DEFAULT FALSE; \ No newline at end of file diff --git a/synapse/storage/schema/main/delta/91/01_media_hash.sql b/synapse/storage/schema/main/delta/91/01_media_hash.sql new file mode 100644
index 0000000000..34a372f1ed --- /dev/null +++ b/synapse/storage/schema/main/delta/91/01_media_hash.sql
@@ -0,0 +1,28 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- <https://www.gnu.org/licenses/agpl-3.0.html>. + +-- Store the SHA256 content hash of media files. +ALTER TABLE local_media_repository ADD COLUMN sha256 TEXT; +ALTER TABLE remote_media_cache ADD COLUMN sha256 TEXT; + +-- Add background updates to handle creating the new indexes. +-- +-- Note that the ordering of the updates does not follow the usual scheme. This +-- is because when upgrading from Synapse 1.127, these indexes are fairly important +-- to have up quickly, so that it doesn't tank performance, which is why they are +-- scheduled before other background updates in the 1.127 -> 1.128 upgrade +INSERT INTO + background_updates (ordering, update_name, progress_json) +VALUES + (8890, 'local_media_repository_sha256_idx', '{}'), + (8891, 'remote_media_cache_sha256_idx', '{}'); diff --git a/synapse/storage/schema/main/delta/92/01_remove_trigger.sql.postgres b/synapse/storage/schema/main/delta/92/01_remove_trigger.sql.postgres new file mode 100644
index 0000000000..e9f160cdcc --- /dev/null +++ b/synapse/storage/schema/main/delta/92/01_remove_trigger.sql.postgres
@@ -0,0 +1,16 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- <https://www.gnu.org/licenses/agpl-3.0.html>. + +-- Removes the trigger that was added in #18260 and then reverted +DROP TRIGGER IF EXISTS event_stats_increment_counts_trigger ON events; +DROP FUNCTION IF EXISTS event_stats_increment_counts(); diff --git a/synapse/storage/schema/main/delta/92/01_remove_trigger.sql.sqlite b/synapse/storage/schema/main/delta/92/01_remove_trigger.sql.sqlite new file mode 100644
index 0000000000..b5f084dde8 --- /dev/null +++ b/synapse/storage/schema/main/delta/92/01_remove_trigger.sql.sqlite
@@ -0,0 +1,16 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- <https://www.gnu.org/licenses/agpl-3.0.html>. + +-- Removes the trigger that was added in #18260 and then reverted +DROP TRIGGER IF EXISTS event_stats_events_insert_trigger; +DROP TRIGGER IF EXISTS event_stats_events_delete_trigger; diff --git a/synapse/storage/schema/main/delta/92/02_remove_populate_participant_bg_update.sql b/synapse/storage/schema/main/delta/92/02_remove_populate_participant_bg_update.sql new file mode 100644
index 0000000000..e1f377c37d --- /dev/null +++ b/synapse/storage/schema/main/delta/92/02_remove_populate_participant_bg_update.sql
@@ -0,0 +1,17 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- <https://www.gnu.org/licenses/agpl-3.0.html>. + +-- Remove the background update if it was scheduled, as it is not rollback-safe +-- See https://github.com/element-hq/synapse/issues/18356 for context +DELETE FROM background_updates +WHERE update_name = 'populate_participant_bg_update'; \ No newline at end of file diff --git a/synapse/storage/schema/main/delta/92/04_ss_membership_snapshot_idx.sql b/synapse/storage/schema/main/delta/92/04_ss_membership_snapshot_idx.sql new file mode 100644
index 0000000000..6f5b7cb06e --- /dev/null +++ b/synapse/storage/schema/main/delta/92/04_ss_membership_snapshot_idx.sql
@@ -0,0 +1,16 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- <https://www.gnu.org/licenses/agpl-3.0.html>. + +-- So we can fetch all rooms for a given user sorted by stream order +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (9204, 'sliding_sync_membership_snapshots_user_id_stream_ordering', '{}'); diff --git a/synapse/storage/schema/main/delta/92/05_fixup_max_depth_cap.sql b/synapse/storage/schema/main/delta/92/05_fixup_max_depth_cap.sql new file mode 100644
index 0000000000..c1ebf8b58b --- /dev/null +++ b/synapse/storage/schema/main/delta/92/05_fixup_max_depth_cap.sql
@@ -0,0 +1,17 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- <https://www.gnu.org/licenses/agpl-3.0.html>. + +-- Background update that fixes any events with a topological ordering above the +-- MAX_DEPTH value. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (9205, 'fixup_max_depth_cap', '{}'); diff --git a/synapse/storage/schema/state/delta/89/01_state_groups_deletion.sql b/synapse/storage/schema/state/delta/89/01_state_groups_deletion.sql new file mode 100644
index 0000000000..d4cb27a3a2 --- /dev/null +++ b/synapse/storage/schema/state/delta/89/01_state_groups_deletion.sql
@@ -0,0 +1,39 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- <https://www.gnu.org/licenses/agpl-3.0.html>. + +-- See the `StateDeletionDataStore` for details of these tables. + +-- We add state groups to this table when we want to later delete them. The +-- `insertion_ts` column indicates when the state group was proposed for +-- deletion (rather than when it should be deleted). +CREATE TABLE IF NOT EXISTS state_groups_pending_deletion ( + sequence_number $%AUTO_INCREMENT_PRIMARY_KEY%$, + state_group BIGINT NOT NULL, + insertion_ts BIGINT NOT NULL +); + +CREATE UNIQUE INDEX state_groups_pending_deletion_state_group ON state_groups_pending_deletion(state_group); +CREATE INDEX state_groups_pending_deletion_insertion_ts ON state_groups_pending_deletion(insertion_ts); + + +-- Holds the state groups the worker is currently persisting. +-- +-- The `sequence_number` column of the `state_groups_pending_deletion` table +-- *must* be updated whenever a state group may have become referenced. +CREATE TABLE IF NOT EXISTS state_groups_persisting ( + state_group BIGINT NOT NULL, + instance_name TEXT NOT NULL, + PRIMARY KEY (state_group, instance_name) +); + +CREATE INDEX state_groups_persisting_instance_name ON state_groups_persisting(instance_name); diff --git a/synapse/storage/schema/state/delta/90/02_delete_unreferenced_state_groups.sql b/synapse/storage/schema/state/delta/90/02_delete_unreferenced_state_groups.sql new file mode 100644
index 0000000000..55a038e2b8 --- /dev/null +++ b/synapse/storage/schema/state/delta/90/02_delete_unreferenced_state_groups.sql
@@ -0,0 +1,16 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- <https://www.gnu.org/licenses/agpl-3.0.html>. + +-- Add a background update to delete any unreferenced state groups +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (9002, 'mark_unreferenced_state_groups_for_deletion_bg_update', '{}'); diff --git a/synapse/storage/schema/state/delta/90/03_remove_old_deletion_bg_update.sql b/synapse/storage/schema/state/delta/90/03_remove_old_deletion_bg_update.sql new file mode 100644
index 0000000000..1cc6d612b6 --- /dev/null +++ b/synapse/storage/schema/state/delta/90/03_remove_old_deletion_bg_update.sql
@@ -0,0 +1,15 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- <https://www.gnu.org/licenses/agpl-3.0.html>. + +-- Remove the old unreferenced state group deletion background update if it exists +DELETE FROM background_updates WHERE update_name = 'delete_unreferenced_state_groups_bg_update'; diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 74f60cc590..4329d88c9a 100644 --- a/synapse/storage/types.py +++ b/synapse/storage/types.py
@@ -26,14 +26,13 @@ from typing import ( List, Mapping, Optional, + Protocol, Sequence, Tuple, Type, Union, ) -from typing_extensions import Protocol - """ Some very basic protocol definitions for the DB-API2 classes specified in PEP-249 """ diff --git a/synapse/synapse_rust/events.pyi b/synapse/synapse_rust/events.pyi
index 1682d0d151..7d3422572d 100644 --- a/synapse/synapse_rust/events.pyi +++ b/synapse/synapse_rust/events.pyi
@@ -10,7 +10,7 @@ # See the GNU Affero General Public License for more details: # <https://www.gnu.org/licenses/agpl-3.0.html>. -from typing import Optional +from typing import List, Mapping, Optional, Tuple from synapse.types import JsonDict @@ -105,3 +105,29 @@ class EventInternalMetadata: def is_notifiable(self) -> bool: """Whether this event can trigger a push notification""" + +def event_visible_to_server( + sender: str, + target_server_name: str, + history_visibility: str, + erased_senders: Mapping[str, bool], + partial_state_invisible: bool, + memberships: List[Tuple[str, str]], +) -> bool: + """Determine whether the server is allowed to see the unredacted event. + + Args: + sender: The sender of the event. + target_server_name: The server we want to send the event to. + history_visibility: The history_visibility value at the event. + erased_senders: A mapping of users and whether they have requested erasure. If a + user is not in the map, it is treated as though they haven't requested erasure. + partial_state_invisible: Whether the event should be treated as invisible due to + the partial state status of the room. + memberships: A list of membership state information at the event for users + matching the `target_server_name`. Each list item must contain a tuple of + (state_key, membership). + + Returns: + Whether the server is allowed to see the unredacted event. + """ diff --git a/synapse/synapse_rust/push.pyi b/synapse/synapse_rust/push.pyi
index 27a974e1bb..3f317c3288 100644 --- a/synapse/synapse_rust/push.pyi +++ b/synapse/synapse_rust/push.pyi
@@ -48,6 +48,7 @@ class FilteredPushRules: msc3381_polls_enabled: bool, msc3664_enabled: bool, msc4028_push_encrypted_events: bool, + msc4210_enabled: bool, ): ... def rules(self) -> Collection[Tuple[PushRule, bool]]: ... @@ -65,6 +66,7 @@ class PushRuleEvaluator: related_event_match_enabled: bool, room_version_feature_flags: Tuple[str, ...], msc3931_enabled: bool, + msc4210_enabled: bool, ): ... def run( self, diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 5259550f1c..5549f3c9f8 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py
@@ -40,6 +40,7 @@ from typing import ( Set, Tuple, Type, + TypedDict, TypeVar, Union, overload, @@ -49,7 +50,7 @@ import attr from immutabledict import immutabledict from signedjson.key import decode_verify_key_bytes from signedjson.types import VerifyKey -from typing_extensions import Self, TypedDict +from typing_extensions import Self from unpaddedbase64 import decode_base64 from zope.interface import Interface @@ -664,6 +665,11 @@ class RoomStreamToken(AbstractMultiWriterStreamToken): @classmethod async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken": + # Check that it looks like a Synapse token first. We do this so that + # we don't log at the exception-level for obviously incorrect tokens. + if not string or string[0] not in ("s", "t", "m"): + raise SynapseError(400, f"Invalid room stream token {string!r}") + try: if string[0] == "s": return cls(topological=None, stream=int(string[1:])) @@ -883,8 +889,7 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken): def __str__(self) -> str: instances = ", ".join(f"{k}: {v}" for k, v in sorted(self.instance_map.items())) return ( - f"MultiWriterStreamToken(stream: {self.stream}, " - f"instances: {{{instances}}})" + f"MultiWriterStreamToken(stream: {self.stream}, instances: {{{instances}}})" ) @@ -1308,7 +1313,7 @@ class DeviceListUpdates: def get_verify_key_from_cross_signing_key( - key_info: Mapping[str, Any] + key_info: Mapping[str, Any], ) -> Tuple[str, VerifyKey]: """Get the key ID and signedjson verify key from a cross-signing key dict diff --git a/synapse/types/handlers/__init__.py b/synapse/types/handlers/__init__.py
index 363f060bef..f2fbc1dddf 100644 --- a/synapse/types/handlers/__init__.py +++ b/synapse/types/handlers/__init__.py
@@ -17,33 +17,23 @@ # [This file includes modifications made by New Vector Limited] # # -from enum import Enum -from typing import TYPE_CHECKING, Dict, Final, List, Mapping, Optional, Sequence, Tuple -import attr -from typing_extensions import TypedDict -from synapse._pydantic_compat import HAS_PYDANTIC_V2 +from typing import List, Optional, TypedDict -if TYPE_CHECKING or HAS_PYDANTIC_V2: - from pydantic.v1 import Extra -else: - from pydantic import Extra +from synapse.api.constants import EventTypes -from synapse.events import EventBase -from synapse.types import ( - DeviceListUpdates, - JsonDict, - JsonMapping, - Requester, - SlidingSyncStreamToken, - StreamToken, - UserID, -) -from synapse.types.rest.client import SlidingSyncBody - -if TYPE_CHECKING: - from synapse.handlers.relations import BundledAggregations +# Sliding Sync: The event types that clients should consider as new activity and affect +# the `bump_stamp` +SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES = { + EventTypes.Create, + EventTypes.Message, + EventTypes.Encrypted, + EventTypes.Sticker, + EventTypes.CallInvite, + EventTypes.PollStart, + EventTypes.LiveLocationShareStart, +} class ShutdownRoomParams(TypedDict): @@ -101,331 +91,3 @@ class ShutdownRoomResponse(TypedDict): failed_to_kick_users: List[str] local_aliases: List[str] new_room_id: Optional[str] - - -class SlidingSyncConfig(SlidingSyncBody): - """ - Inherit from `SlidingSyncBody` since we need all of the same fields and add a few - extra fields that we need in the handler - """ - - user: UserID - requester: Requester - - # Pydantic config - class Config: - # By default, ignore fields that we don't recognise. - extra = Extra.ignore - # By default, don't allow fields to be reassigned after parsing. - allow_mutation = False - # Allow custom types like `UserID` to be used in the model - arbitrary_types_allowed = True - - -class OperationType(Enum): - """ - Represents the operation types in a Sliding Sync window. 
- - Attributes: - SYNC: Sets a range of entries. Clients SHOULD discard what they previous knew about - entries in this range. - INSERT: Sets a single entry. If the position is not empty then clients MUST move - entries to the left or the right depending on where the closest empty space is. - DELETE: Remove a single entry. Often comes before an INSERT to allow entries to move - places. - INVALIDATE: Remove a range of entries. Clients MAY persist the invalidated range for - offline support, but they should be treated as empty when additional operations - which concern indexes in the range arrive from the server. - """ - - SYNC: Final = "SYNC" - INSERT: Final = "INSERT" - DELETE: Final = "DELETE" - INVALIDATE: Final = "INVALIDATE" - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class SlidingSyncResult: - """ - The Sliding Sync result to be serialized to JSON for a response. - - Attributes: - next_pos: The next position token in the sliding window to request (next_batch). - lists: Sliding window API. A map of list key to list results. - rooms: Room subscription API. A map of room ID to room results. - extensions: Extensions API. A map of extension key to extension results. - """ - - @attr.s(slots=True, frozen=True, auto_attribs=True) - class RoomResult: - """ - Attributes: - name: Room name or calculated room name. - avatar: Room avatar - heroes: List of stripped membership events (containing `user_id` and optionally - `avatar_url` and `displayname`) for the users used to calculate the room name. - is_dm: Flag to specify whether the room is a direct-message room (most likely - between two people). - initial: Flag which is set when this is the first time the server is sending this - data on this connection. Clients can use this flag to replace or update - their local state. When there is an update, servers MUST omit this flag - entirely and NOT send "initial":false as this is wasteful on bandwidth. The - absence of this flag means 'false'. 
- required_state: The current state of the room - timeline: Latest events in the room. The last event is the most recent. - bundled_aggregations: A mapping of event ID to the bundled aggregations for - the timeline events above. This allows clients to show accurate reaction - counts (or edits, threads), even if some of the reaction events were skipped - over in a gappy sync. - stripped_state: Stripped state events (for rooms where the usre is - invited/knocked). Same as `rooms.invite.$room_id.invite_state` in sync v2, - absent on joined/left rooms - prev_batch: A token that can be passed as a start parameter to the - `/rooms/<room_id>/messages` API to retrieve earlier messages. - limited: True if there are more events than `timeline_limit` looking - backwards from the `response.pos` to the `request.pos`. - num_live: The number of timeline events which have just occurred and are not historical. - The last N events are 'live' and should be treated as such. This is mostly - useful to determine whether a given @mention event should make a noise or not. - Clients cannot rely solely on the absence of `initial: true` to determine live - events because if a room not in the sliding window bumps into the window because - of an @mention it will have `initial: true` yet contain a single live event - (with potentially other old events in the timeline). - bump_stamp: The `stream_ordering` of the last event according to the - `bump_event_types`. This helps clients sort more readily without them - needing to pull in a bunch of the timeline to determine the last activity. - `bump_event_types` is a thing because for example, we don't want display - name changes to mark the room as unread and bump it to the top. For - encrypted rooms, we just have to consider any activity as a bump because we - can't see the content and the client has to figure it out for themselves. - joined_count: The number of users with membership of join, including the client's - own user ID. 
(same as sync `v2 m.joined_member_count`) - invited_count: The number of users with membership of invite. (same as sync v2 - `m.invited_member_count`) - notification_count: The total number of unread notifications for this room. (same - as sync v2) - highlight_count: The number of unread notifications for this room with the highlight - flag set. (same as sync v2) - """ - - @attr.s(slots=True, frozen=True, auto_attribs=True) - class StrippedHero: - user_id: str - display_name: Optional[str] - avatar_url: Optional[str] - - name: Optional[str] - avatar: Optional[str] - heroes: Optional[List[StrippedHero]] - is_dm: bool - initial: bool - # Should be empty for invite/knock rooms with `stripped_state` - required_state: List[EventBase] - # Should be empty for invite/knock rooms with `stripped_state` - timeline_events: List[EventBase] - bundled_aggregations: Optional[Dict[str, "BundledAggregations"]] - # Optional because it's only relevant to invite/knock rooms - stripped_state: List[JsonDict] - # Only optional because it won't be included for invite/knock rooms with `stripped_state` - prev_batch: Optional[StreamToken] - # Only optional because it won't be included for invite/knock rooms with `stripped_state` - limited: Optional[bool] - # Only optional because it won't be included for invite/knock rooms with `stripped_state` - num_live: Optional[int] - bump_stamp: int - joined_count: int - invited_count: int - notification_count: int - highlight_count: int - - def __bool__(self) -> bool: - return ( - # If this is the first time the client is seeing the room, we should not filter it out - # under any circumstance. - self.initial - # We need to let the client know if there are any new events - or bool(self.required_state) - or bool(self.timeline_events) - or bool(self.stripped_state) - ) - - @attr.s(slots=True, frozen=True, auto_attribs=True) - class SlidingWindowList: - """ - Attributes: - count: The total number of entries in the list. Always present if this list - is. 
- ops: The sliding list operations to perform. - """ - - @attr.s(slots=True, frozen=True, auto_attribs=True) - class Operation: - """ - Attributes: - op: The operation type to perform. - range: Which index positions are affected by this operation. These are - both inclusive. - room_ids: Which room IDs are affected by this operation. These IDs match - up to the positions in the `range`, so the last room ID in this list - matches the 9th index. The room data is held in a separate object. - """ - - op: OperationType - range: Tuple[int, int] - room_ids: List[str] - - count: int - ops: List[Operation] - - @attr.s(slots=True, frozen=True, auto_attribs=True) - class Extensions: - """Responses for extensions - - Attributes: - to_device: The to-device extension (MSC3885) - e2ee: The E2EE device extension (MSC3884) - """ - - @attr.s(slots=True, frozen=True, auto_attribs=True) - class ToDeviceExtension: - """The to-device extension (MSC3885) - - Attributes: - next_batch: The to-device stream token the client should use - to get more results - events: A list of to-device messages for the client - """ - - next_batch: str - events: Sequence[JsonMapping] - - def __bool__(self) -> bool: - return bool(self.events) - - @attr.s(slots=True, frozen=True, auto_attribs=True) - class E2eeExtension: - """The E2EE device extension (MSC3884) - - Attributes: - device_list_updates: List of user_ids whose devices have changed or left (only - present on incremental syncs). - device_one_time_keys_count: Map from key algorithm to the number of - unclaimed one-time keys currently held on the server for this device. If - an algorithm is unlisted, the count for that algorithm is assumed to be - zero. If this entire parameter is missing, the count for all algorithms - is assumed to be zero. - device_unused_fallback_key_types: List of unused fallback key algorithms - for this device. 
- """ - - # Only present on incremental syncs - device_list_updates: Optional[DeviceListUpdates] - device_one_time_keys_count: Mapping[str, int] - device_unused_fallback_key_types: Sequence[str] - - def __bool__(self) -> bool: - # Note that "signed_curve25519" is always returned in key count responses - # regardless of whether we uploaded any keys for it. This is necessary until - # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed. - # - # Also related: - # https://github.com/element-hq/element-android/issues/3725 and - # https://github.com/matrix-org/synapse/issues/10456 - default_otk = self.device_one_time_keys_count.get("signed_curve25519") - more_than_default_otk = len(self.device_one_time_keys_count) > 1 or ( - default_otk is not None and default_otk > 0 - ) - - return bool( - more_than_default_otk - or self.device_list_updates - or self.device_unused_fallback_key_types - ) - - @attr.s(slots=True, frozen=True, auto_attribs=True) - class AccountDataExtension: - """The Account Data extension (MSC3959) - - Attributes: - global_account_data_map: Mapping from `type` to `content` of global account - data events. - account_data_by_room_map: Mapping from room_id to mapping of `type` to - `content` of room account data events. 
- """ - - global_account_data_map: Mapping[str, JsonMapping] - account_data_by_room_map: Mapping[str, Mapping[str, JsonMapping]] - - def __bool__(self) -> bool: - return bool( - self.global_account_data_map or self.account_data_by_room_map - ) - - @attr.s(slots=True, frozen=True, auto_attribs=True) - class ReceiptsExtension: - """The Receipts extension (MSC3960) - - Attributes: - room_id_to_receipt_map: Mapping from room_id to `m.receipt` ephemeral - event (type, content) - """ - - room_id_to_receipt_map: Mapping[str, JsonMapping] - - def __bool__(self) -> bool: - return bool(self.room_id_to_receipt_map) - - @attr.s(slots=True, frozen=True, auto_attribs=True) - class TypingExtension: - """The Typing Notification extension (MSC3961) - - Attributes: - room_id_to_typing_map: Mapping from room_id to `m.typing` ephemeral - event (type, content) - """ - - room_id_to_typing_map: Mapping[str, JsonMapping] - - def __bool__(self) -> bool: - return bool(self.room_id_to_typing_map) - - to_device: Optional[ToDeviceExtension] = None - e2ee: Optional[E2eeExtension] = None - account_data: Optional[AccountDataExtension] = None - receipts: Optional[ReceiptsExtension] = None - typing: Optional[TypingExtension] = None - - def __bool__(self) -> bool: - return bool( - self.to_device - or self.e2ee - or self.account_data - or self.receipts - or self.typing - ) - - next_pos: SlidingSyncStreamToken - lists: Dict[str, SlidingWindowList] - rooms: Dict[str, RoomResult] - extensions: Extensions - - def __bool__(self) -> bool: - """Make the result appear empty if there are no updates. This is used - to tell if the notifier needs to wait for more events when polling for - events. - """ - # We don't include `self.lists` here, as a) `lists` is always non-empty even if - # there are no changes, and b) since we're sorting rooms by `stream_ordering` of - # the latest activity, anything that would cause the order to change would end - # up in `self.rooms` and cause us to send down the change. 
- return bool(self.rooms or self.extensions) - - @staticmethod - def empty(next_pos: SlidingSyncStreamToken) -> "SlidingSyncResult": - "Return a new empty result" - return SlidingSyncResult( - next_pos=next_pos, - lists={}, - rooms={}, - extensions=SlidingSyncResult.Extensions(), - ) diff --git a/synapse/types/handlers/policy_server.py b/synapse/types/handlers/policy_server.py new file mode 100644
index 0000000000..bfc09dabf4 --- /dev/null +++ b/synapse/types/handlers/policy_server.py
@@ -0,0 +1,16 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# + +RECOMMENDATION_OK = "ok" +RECOMMENDATION_SPAM = "spam" diff --git a/synapse/types/handlers/sliding_sync.py b/synapse/types/handlers/sliding_sync.py new file mode 100644
index 0000000000..3ebd334a6d --- /dev/null +++ b/synapse/types/handlers/sliding_sync.py
@@ -0,0 +1,875 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2024 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# + +import logging +import typing +from collections import ChainMap +from enum import Enum +from typing import ( + TYPE_CHECKING, + AbstractSet, + Any, + Callable, + Dict, + Final, + Generic, + List, + Mapping, + MutableMapping, + Optional, + Sequence, + Set, + Tuple, + TypeVar, + cast, +) + +import attr + +from synapse._pydantic_compat import Extra +from synapse.api.constants import EventTypes +from synapse.events import EventBase +from synapse.types import ( + DeviceListUpdates, + JsonDict, + JsonMapping, + MultiWriterStreamToken, + Requester, + RoomStreamToken, + SlidingSyncStreamToken, + StrCollection, + StreamToken, + UserID, +) +from synapse.types.rest.client import SlidingSyncBody + +if TYPE_CHECKING: + from synapse.handlers.relations import BundledAggregations + +logger = logging.getLogger(__name__) + + +class SlidingSyncConfig(SlidingSyncBody): + """ + Inherit from `SlidingSyncBody` since we need all of the same fields and add a few + extra fields that we need in the handler + """ + + user: UserID + requester: Requester + + # Pydantic config + class Config: + # By default, ignore fields that we don't recognise. + extra = Extra.ignore + # By default, don't allow fields to be reassigned after parsing. + allow_mutation = False + # Allow custom types like `UserID` to be used in the model + arbitrary_types_allowed = True + + +class OperationType(Enum): + """ + Represents the operation types in a Sliding Sync window. 
+ + Attributes: + SYNC: Sets a range of entries. Clients SHOULD discard what they previously knew about + entries in this range. + INSERT: Sets a single entry. If the position is not empty then clients MUST move + entries to the left or the right depending on where the closest empty space is. + DELETE: Remove a single entry. Often comes before an INSERT to allow entries to move + places. + INVALIDATE: Remove a range of entries. Clients MAY persist the invalidated range for + offline support, but they should be treated as empty when additional operations + which concern indexes in the range arrive from the server. + """ + + SYNC: Final = "SYNC" + INSERT: Final = "INSERT" + DELETE: Final = "DELETE" + INVALIDATE: Final = "INVALIDATE" + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class SlidingSyncResult: + """ + The Sliding Sync result to be serialized to JSON for a response. + + Attributes: + next_pos: The next position token in the sliding window to request (next_batch). + lists: Sliding window API. A map of list key to list results. + rooms: Room subscription API. A map of room ID to room results. + extensions: Extensions API. A map of extension key to extension results. + """ + + @attr.s(slots=True, frozen=True, auto_attribs=True) + class RoomResult: + """ + Attributes: + name: Room name or calculated room name. + avatar: Room avatar + heroes: List of stripped membership events (containing `user_id` and optionally + `avatar_url` and `displayname`) for the users used to calculate the room name. + is_dm: Flag to specify whether the room is a direct-message room (most likely + between two people). + initial: Flag which is set when this is the first time the server is sending this + data on this connection. Clients can use this flag to replace or update + their local state. When there is an update, servers MUST omit this flag + entirely and NOT send "initial":false as this is wasteful on bandwidth. The + absence of this flag means 'false'.
+ unstable_expanded_timeline: Flag which is set if we're returning more historic + events due to the timeline limit having increased. See "XXX: Odd behavior" + comment in `synapse.handlers.sliding_sync`. + required_state: The current state of the room + timeline: Latest events in the room. The last event is the most recent. + bundled_aggregations: A mapping of event ID to the bundled aggregations for + the timeline events above. This allows clients to show accurate reaction + counts (or edits, threads), even if some of the reaction events were skipped + over in a gappy sync. + stripped_state: Stripped state events (for rooms where the user is + invited/knocked). Same as `rooms.invite.$room_id.invite_state` in sync v2, + absent on joined/left rooms + prev_batch: A token that can be passed as a start parameter to the + `/rooms/<room_id>/messages` API to retrieve earlier messages. + limited: True if there are more events than `timeline_limit` looking + backwards from the `response.pos` to the `request.pos`. + num_live: The number of timeline events which have just occurred and are not historical. + The last N events are 'live' and should be treated as such. This is mostly + useful to determine whether a given @mention event should make a noise or not. + Clients cannot rely solely on the absence of `initial: true` to determine live + events because if a room not in the sliding window bumps into the window because + of an @mention it will have `initial: true` yet contain a single live event + (with potentially other old events in the timeline). + bump_stamp: The `stream_ordering` of the last event according to the + `bump_event_types`. This helps clients sort more readily without them + needing to pull in a bunch of the timeline to determine the last activity. + `bump_event_types` is a thing because for example, we don't want display + name changes to mark the room as unread and bump it to the top.
For + encrypted rooms, we just have to consider any activity as a bump because we + can't see the content and the client has to figure it out for themselves. + This may not be included if there hasn't been a change. + joined_count: The number of users with membership of join, including the client's + own user ID. (same as sync v2 `m.joined_member_count`) + invited_count: The number of users with membership of invite. (same as sync v2 + `m.invited_member_count`) + notification_count: The total number of unread notifications for this room. (same + as sync v2) + highlight_count: The number of unread notifications for this room with the highlight + flag set. (same as sync v2) + """ + + @attr.s(slots=True, frozen=True, auto_attribs=True) + class StrippedHero: + user_id: str + display_name: Optional[str] + avatar_url: Optional[str] + + name: Optional[str] + avatar: Optional[str] + heroes: Optional[List[StrippedHero]] + is_dm: bool + initial: bool + unstable_expanded_timeline: bool + # Should be empty for invite/knock rooms with `stripped_state` + required_state: List[EventBase] + # Should be empty for invite/knock rooms with `stripped_state` + timeline_events: List[EventBase] + bundled_aggregations: Optional[Dict[str, "BundledAggregations"]] + # Optional because it's only relevant to invite/knock rooms + stripped_state: List[JsonDict] + # Only optional because it won't be included for invite/knock rooms with `stripped_state` + prev_batch: Optional[StreamToken] + # Only optional because it won't be included for invite/knock rooms with `stripped_state` + limited: Optional[bool] + # Only optional because it won't be included for invite/knock rooms with `stripped_state` + num_live: Optional[int] + bump_stamp: Optional[int] + joined_count: Optional[int] + invited_count: Optional[int] + notification_count: int + highlight_count: int + + def __bool__(self) -> bool: + return ( + # If this is the first time the client is seeing the room, we should not filter it out + # under any
circumstance. + self.initial + # We need to let the client know if any of the info has changed + or self.name is not None + or self.avatar is not None + or bool(self.heroes) + or self.joined_count is not None + or self.invited_count is not None + # We need to let the client know if there are any new events + or bool(self.required_state) + or bool(self.timeline_events) + or bool(self.stripped_state) + ) + + @attr.s(slots=True, frozen=True, auto_attribs=True) + class SlidingWindowList: + """ + Attributes: + count: The total number of entries in the list. Always present if this list + is. + ops: The sliding list operations to perform. + """ + + @attr.s(slots=True, frozen=True, auto_attribs=True) + class Operation: + """ + Attributes: + op: The operation type to perform. + range: Which index positions are affected by this operation. These are + both inclusive. + room_ids: Which room IDs are affected by this operation. These IDs match + up to the positions in the `range`, so the last room ID in this list + matches the 9th index. The room data is held in a separate object. 
+ """ + + op: OperationType + range: Tuple[int, int] + room_ids: List[str] + + count: int + ops: List[Operation] + + @attr.s(slots=True, frozen=True, auto_attribs=True) + class Extensions: + """Responses for extensions + + Attributes: + to_device: The to-device extension (MSC3885) + e2ee: The E2EE device extension (MSC3884) + """ + + @attr.s(slots=True, frozen=True, auto_attribs=True) + class ToDeviceExtension: + """The to-device extension (MSC3885) + + Attributes: + next_batch: The to-device stream token the client should use + to get more results + events: A list of to-device messages for the client + """ + + next_batch: str + events: Sequence[JsonMapping] + + def __bool__(self) -> bool: + return bool(self.events) + + @attr.s(slots=True, frozen=True, auto_attribs=True) + class E2eeExtension: + """The E2EE device extension (MSC3884) + + Attributes: + device_list_updates: List of user_ids whose devices have changed or left (only + present on incremental syncs). + device_one_time_keys_count: Map from key algorithm to the number of + unclaimed one-time keys currently held on the server for this device. If + an algorithm is unlisted, the count for that algorithm is assumed to be + zero. If this entire parameter is missing, the count for all algorithms + is assumed to be zero. + device_unused_fallback_key_types: List of unused fallback key algorithms + for this device. + """ + + # Only present on incremental syncs + device_list_updates: Optional[DeviceListUpdates] + device_one_time_keys_count: Mapping[str, int] + device_unused_fallback_key_types: Sequence[str] + + def __bool__(self) -> bool: + # Note that "signed_curve25519" is always returned in key count responses + # regardless of whether we uploaded any keys for it. This is necessary until + # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed. 
+ # + # Also related: + # https://github.com/element-hq/element-android/issues/3725 and + # https://github.com/matrix-org/synapse/issues/10456 + default_otk = self.device_one_time_keys_count.get("signed_curve25519") + more_than_default_otk = len(self.device_one_time_keys_count) > 1 or ( + default_otk is not None and default_otk > 0 + ) + + return bool( + more_than_default_otk + or self.device_list_updates + or self.device_unused_fallback_key_types + ) + + @attr.s(slots=True, frozen=True, auto_attribs=True) + class AccountDataExtension: + """The Account Data extension (MSC3959) + + Attributes: + global_account_data_map: Mapping from `type` to `content` of global + account data events. + account_data_by_room_map: Mapping from room_id to mapping of `type` to + `content` of room account data events. + """ + + global_account_data_map: Mapping[str, JsonMapping] + account_data_by_room_map: Mapping[str, Mapping[str, JsonMapping]] + + def __bool__(self) -> bool: + return bool( + self.global_account_data_map or self.account_data_by_room_map + ) + + @attr.s(slots=True, frozen=True, auto_attribs=True) + class ReceiptsExtension: + """The Receipts extension (MSC3960) + + Attributes: + room_id_to_receipt_map: Mapping from room_id to `m.receipt` ephemeral + event (type, content) + """ + + room_id_to_receipt_map: Mapping[str, JsonMapping] + + def __bool__(self) -> bool: + return bool(self.room_id_to_receipt_map) + + @attr.s(slots=True, frozen=True, auto_attribs=True) + class TypingExtension: + """The Typing Notification extension (MSC3961) + + Attributes: + room_id_to_typing_map: Mapping from room_id to `m.typing` ephemeral + event (type, content) + """ + + room_id_to_typing_map: Mapping[str, JsonMapping] + + def __bool__(self) -> bool: + return bool(self.room_id_to_typing_map) + + to_device: Optional[ToDeviceExtension] = None + e2ee: Optional[E2eeExtension] = None + account_data: Optional[AccountDataExtension] = None + receipts: Optional[ReceiptsExtension] = None + typing: 
Optional[TypingExtension] = None + + def __bool__(self) -> bool: + return bool( + self.to_device + or self.e2ee + or self.account_data + or self.receipts + or self.typing + ) + + next_pos: SlidingSyncStreamToken + lists: Mapping[str, SlidingWindowList] + rooms: Dict[str, RoomResult] + extensions: Extensions + + def __bool__(self) -> bool: + """Make the result appear empty if there are no updates. This is used + to tell if the notifier needs to wait for more events when polling for + events. + """ + # We don't include `self.lists` here, as a) `lists` is always non-empty even if + # there are no changes, and b) since we're sorting rooms by `stream_ordering` of + # the latest activity, anything that would cause the order to change would end + # up in `self.rooms` and cause us to send down the change. + return bool(self.rooms or self.extensions) + + @staticmethod + def empty(next_pos: SlidingSyncStreamToken) -> "SlidingSyncResult": + "Return a new empty result" + return SlidingSyncResult( + next_pos=next_pos, + lists={}, + rooms={}, + extensions=SlidingSyncResult.Extensions(), + ) + + +class StateValues: + """ + Understood values of the (type, state_key) tuple in `required_state`. + """ + + # Include all state events of the given type + WILDCARD: Final = "*" + # Lazy-load room membership events (include room membership events for any event + # `sender` or membership change target in the timeline). We only give special + # meaning to this value when it's a `state_key`. + LAZY: Final = "$LAZY" + # Substitute with the requester's user ID. Typically used by clients to get + # the user's membership. + ME: Final = "$ME" + + +# This class is frozen, so it can't be updated in place: combining configs returns +# a new instance holding the de-duplicated data. +@attr.s(slots=True, auto_attribs=True, frozen=True) +class RoomSyncConfig: + """ + Holds the config for what data we should fetch for a room in the sync response. + + Attributes: + timeline_limit: The maximum number of events to return in the timeline.
+ + required_state_map: Map from state event type to state_keys requested for the + room. The values are close to `StateKey` but actually use a syntax where you + can provide `*` wildcard and `$LAZY` for lazy-loading room members. + """ + + timeline_limit: int + required_state_map: Mapping[str, AbstractSet[str]] + + @classmethod + def from_room_config( + cls, + room_params: SlidingSyncConfig.CommonRoomParameters, + ) -> "RoomSyncConfig": + """ + Create a `RoomSyncConfig` from a `SlidingSyncList`/`RoomSubscription` config. + + Args: + room_params: `SlidingSyncConfig.SlidingSyncList` or `SlidingSyncConfig.RoomSubscription` + """ + required_state_map: Dict[str, Set[str]] = {} + for ( + state_type, + state_key, + ) in room_params.required_state: + # If we already have a wildcard for this specific `state_key`, we don't need + # to add it since the wildcard already covers it. + if state_key in required_state_map.get(StateValues.WILDCARD, set()): + continue + + # If we already have a wildcard `state_key` for this `state_type`, we don't need + # to add anything else + if StateValues.WILDCARD in required_state_map.get(state_type, set()): + continue + + # If we're getting wildcards for the `state_type` and `state_key`, that's + # all that matters so get rid of any other entries + if state_type == StateValues.WILDCARD and state_key == StateValues.WILDCARD: + required_state_map = {StateValues.WILDCARD: {StateValues.WILDCARD}} + # We can break, since we don't need to add anything else + break + + # If we're getting a wildcard for the `state_type`, get rid of any other + # entries with the same `state_key`, since the wildcard will cover it already. 
+ elif state_type == StateValues.WILDCARD: + # Get rid of any entries that match the `state_key` + # + # Make a copy so we don't run into an error: `dictionary changed size + # during iteration`, when we remove items + for ( + existing_state_type, + existing_state_key_set, + ) in list(required_state_map.items()): + # Make a copy so we don't run into an error: `Set changed size during + # iteration`, when we filter out and remove items + for existing_state_key in existing_state_key_set.copy(): + if existing_state_key == state_key: + existing_state_key_set.remove(state_key) + + # If we've left the `set()` empty, remove it from the map + if existing_state_key_set == set(): + required_state_map.pop(existing_state_type, None) + + # If we're getting a wildcard `state_key`, get rid of any other state_keys + # for this `state_type` since the wildcard will cover it already. + if state_key == StateValues.WILDCARD: + required_state_map[state_type] = {state_key} + # Otherwise, just add it to the set + else: + if required_state_map.get(state_type) is None: + required_state_map[state_type] = {state_key} + else: + required_state_map[state_type].add(state_key) + + return cls( + timeline_limit=room_params.timeline_limit, + required_state_map=required_state_map, + ) + + def combine_room_sync_config( + self, other_room_sync_config: "RoomSyncConfig" + ) -> "RoomSyncConfig": + """ + Combine this `RoomSyncConfig` with another `RoomSyncConfig` and return the + superset union of the two.
+ """ + timeline_limit = self.timeline_limit + required_state_map = { + event_type: set(state_keys) + for event_type, state_keys in self.required_state_map.items() + } + + # Take the highest timeline limit + if timeline_limit < other_room_sync_config.timeline_limit: + timeline_limit = other_room_sync_config.timeline_limit + + # Union the required state + for ( + state_type, + state_key_set, + ) in other_room_sync_config.required_state_map.items(): + # If we already have a wildcard for everything, we don't need to add + # anything else + if StateValues.WILDCARD in required_state_map.get( + StateValues.WILDCARD, set() + ): + break + + # If we already have a wildcard `state_key` for this `state_type`, we don't need + # to add anything else + if StateValues.WILDCARD in required_state_map.get(state_type, set()): + continue + + # If we're getting wildcards for the `state_type` and `state_key`, that's + # all that matters so get rid of any other entries + if ( + state_type == StateValues.WILDCARD + and StateValues.WILDCARD in state_key_set + ): + required_state_map = {state_type: {StateValues.WILDCARD}} + # We can break, since we don't need to add anything else + break + + for state_key in state_key_set: + # If we already have a wildcard for this specific `state_key`, we don't need + # to add it since the wildcard already covers it. + if state_key in required_state_map.get(StateValues.WILDCARD, set()): + continue + + # If we're getting a wildcard for the `state_type`, get rid of any other + # entries with the same `state_key`, since the wildcard will cover it already. 
+ if state_type == StateValues.WILDCARD: + # Get rid of any entries that match the `state_key` + # + # Make a copy so we don't run into an error: `dictionary changed size + # during iteration`, when we remove items + for existing_state_type, existing_state_key_set in list( + required_state_map.items() + ): + # Make a copy so we don't run into an error: `Set changed size during + # iteration`, when we filter out and remove items + for existing_state_key in existing_state_key_set.copy(): + if existing_state_key == state_key: + existing_state_key_set.remove(state_key) + + # If we've the left the `set()` empty, remove it from the map + if existing_state_key_set == set(): + required_state_map.pop(existing_state_type, None) + + # If we're getting a wildcard `state_key`, get rid of any other state_keys + # for this `state_type` since the wildcard will cover it already. + if state_key == StateValues.WILDCARD: + required_state_map[state_type] = {state_key} + break + # Otherwise, just add it to the set + else: + if required_state_map.get(state_type) is None: + required_state_map[state_type] = {state_key} + else: + required_state_map[state_type].add(state_key) + + return RoomSyncConfig(timeline_limit, required_state_map) + + def must_await_full_state( + self, + is_mine_id: Callable[[str], bool], + ) -> bool: + """ + Check if we have a we're only requesting `required_state` which is completely + satisfied even with partial state, then we don't need to `await_full_state` before + we can return it. + + Also see `StateFilter.must_await_full_state(...)` for comparison + + Partially-stated rooms should have all state events except for remote membership + events so if we require a remote membership event anywhere, then we need to + return `True` (requires full state). 
+ + Args: + is_mine_id: a callable which confirms if a given state_key matches a mxid + of a local user + """ + wildcard_state_keys = self.required_state_map.get(StateValues.WILDCARD) + # Requesting *all* state in the room so we have to wait + if ( + wildcard_state_keys is not None + and StateValues.WILDCARD in wildcard_state_keys + ): + return True + + # If the wildcards don't refer to remote user IDs, then we don't need to wait + # for full state. + if wildcard_state_keys is not None: + for possible_user_id in wildcard_state_keys: + if not possible_user_id[0].startswith(UserID.SIGIL): + # Not a user ID + continue + + localpart_hostname = possible_user_id.split(":", 1) + if len(localpart_hostname) < 2: + # Not a user ID + continue + + if not is_mine_id(possible_user_id): + return True + + membership_state_keys = self.required_state_map.get(EventTypes.Member) + # We aren't requesting any membership events at all so the partial state will + # cover us. + if membership_state_keys is None: + return False + + # If we're requesting entirely local users, the partial state will cover us. + for user_id in membership_state_keys: + if user_id == StateValues.ME: + continue + # We're lazy-loading membership so we can just return the state we have. + # Lazy-loading means we include membership for any event `sender` or + # membership change target in the timeline but since we had to auth those + # timeline events, we will have the membership state for them (including + # from remote senders). + elif user_id == StateValues.LAZY: + continue + elif user_id == StateValues.WILDCARD: + return False + elif not is_mine_id(user_id): + return True + + # Local users only so the partial state will cover us. + return False + + +class HaveSentRoomFlag(Enum): + """Flag for whether we have sent the room down a sliding sync connection. 
+ + The valid state changes here are: + NEVER -> LIVE + LIVE -> PREVIOUSLY + PREVIOUSLY -> LIVE + """ + + # The room has never been sent down (or we have forgotten we have sent it + # down). + NEVER = "never" + + # We have previously sent the room down, but there are updates that we + # haven't sent down. + PREVIOUSLY = "previously" + + # We have sent the room down and the client has received all updates. + LIVE = "live" + + +T = TypeVar("T", str, RoomStreamToken, MultiWriterStreamToken, int) + + +@attr.s(auto_attribs=True, slots=True, frozen=True) +class HaveSentRoom(Generic[T]): + """Whether we have sent the room data down a sliding sync connection. + + We are generic over the type of token used, e.g. `RoomStreamToken` or + `MultiWriterStreamToken`. + + Attributes: + status: Flag of if we have or haven't sent down the room + last_token: If the flag is `PREVIOUSLY` then this is non-null and + contains the last stream token of the last updates we sent down + the room, i.e. we still need to send everything since then to the + client. + """ + + status: HaveSentRoomFlag + last_token: Optional[T] + + @staticmethod + def live() -> "HaveSentRoom[T]": + return HaveSentRoom(HaveSentRoomFlag.LIVE, None) + + @staticmethod + def previously(last_token: T) -> "HaveSentRoom[T]": + """Constructor for `PREVIOUSLY` flag.""" + return HaveSentRoom(HaveSentRoomFlag.PREVIOUSLY, last_token) + + @staticmethod + def never() -> "HaveSentRoom[T]": + # We use a singleton to avoid repeatedly instantiating new `never` + # values. + return _HAVE_SENT_ROOM_NEVER + + +_HAVE_SENT_ROOM_NEVER: HaveSentRoom[Any] = HaveSentRoom(HaveSentRoomFlag.NEVER, None) + + +@attr.s(auto_attribs=True, slots=True, frozen=True) +class RoomStatusMap(Generic[T]): + """For a given stream, e.g. 
events, records what we have or have not sent + down for that stream in a given room.""" + + # `room_id` -> `HaveSentRoom` + _statuses: Mapping[str, HaveSentRoom[T]] = attr.Factory(dict) + + def have_sent_room(self, room_id: str) -> HaveSentRoom[T]: + """Return whether we have previously sent the room down""" + return self._statuses.get(room_id, HaveSentRoom.never()) + + def get_mutable(self) -> "MutableRoomStatusMap[T]": + """Get a mutable copy of this state.""" + return MutableRoomStatusMap( + statuses=self._statuses, + ) + + def copy(self) -> "RoomStatusMap[T]": + """Make a copy of the class. Useful for converting from a mutable to + immutable version.""" + + return RoomStatusMap(statuses=dict(self._statuses)) + + def __len__(self) -> int: + return len(self._statuses) + + +class MutableRoomStatusMap(RoomStatusMap[T]): + """A mutable version of `RoomStatusMap`""" + + # We use a ChainMap here so that we can easily track what has been updated + # and what hasn't. Note that when we persist the per connection state this + # will get flattened to a normal dict (via calling `.copy()`) + _statuses: typing.ChainMap[str, HaveSentRoom[T]] + + def __init__( + self, + statuses: Mapping[str, HaveSentRoom[T]], + ) -> None: + # ChainMap requires a mutable mapping, but we're not actually going to + # mutate it. 
+ statuses = cast(MutableMapping, statuses) + + super().__init__( + statuses=ChainMap({}, statuses), + ) + + def get_updates(self) -> Mapping[str, HaveSentRoom[T]]: + """Return only the changes that were made""" + return self._statuses.maps[0] + + def record_sent_rooms(self, room_ids: StrCollection) -> None: + """Record that we have sent these rooms in the response""" + for room_id in room_ids: + current_status = self._statuses.get(room_id, HaveSentRoom.never()) + if current_status.status == HaveSentRoomFlag.LIVE: + continue + + self._statuses[room_id] = HaveSentRoom.live() + + def record_unsent_rooms(self, room_ids: StrCollection, from_token: T) -> None: + """Record that we have not sent these rooms in the response, but there + have been updates. + """ + # Whether we add/update the entries for unsent rooms depends on the + # existing entry: + # - LIVE: We have previously sent down everything up to + # `last_room_token, so we update the entry to be `PREVIOUSLY` with + # `last_room_token`. + # - PREVIOUSLY: We have previously sent down everything up to *a* + # given token, so we don't need to update the entry. + # - NEVER: We have never previously sent down the room, and we haven't + # sent anything down this time either so we leave it as NEVER. + + for room_id in room_ids: + current_status = self._statuses.get(room_id, HaveSentRoom.never()) + if current_status.status != HaveSentRoomFlag.LIVE: + continue + + self._statuses[room_id] = HaveSentRoom.previously(from_token) + + +@attr.s(auto_attribs=True, frozen=True) +class PerConnectionState: + """The per-connection state. A snapshot of what we've sent down the + connection before. + + Currently, we track whether we've sent down various aspects of a given room + before. + + We use the `rooms` field to store the position in the events stream for each + room that we've previously sent to the client before. On the next request + that includes the room, we can then send only what's changed since that + recorded position. 
+ + Same goes for the `receipts` field so we only need to send the new receipts + since the last time you made a sync request. + + Attributes: + rooms: The status of each room for the events stream. + receipts: The status of each room for the receipts stream. + room_configs: Map from room_id to the `RoomSyncConfig` of all + rooms that we have previously sent down. + """ + + rooms: RoomStatusMap[RoomStreamToken] = attr.Factory(RoomStatusMap) + receipts: RoomStatusMap[MultiWriterStreamToken] = attr.Factory(RoomStatusMap) + account_data: RoomStatusMap[int] = attr.Factory(RoomStatusMap) + + room_configs: Mapping[str, RoomSyncConfig] = attr.Factory(dict) + + def get_mutable(self) -> "MutablePerConnectionState": + """Get a mutable copy of this state.""" + room_configs = cast(MutableMapping[str, RoomSyncConfig], self.room_configs) + + return MutablePerConnectionState( + rooms=self.rooms.get_mutable(), + receipts=self.receipts.get_mutable(), + account_data=self.account_data.get_mutable(), + room_configs=ChainMap({}, room_configs), + ) + + def copy(self) -> "PerConnectionState": + return PerConnectionState( + rooms=self.rooms.copy(), + receipts=self.receipts.copy(), + account_data=self.account_data.copy(), + room_configs=dict(self.room_configs), + ) + + def __len__(self) -> int: + return len(self.rooms) + len(self.receipts) + len(self.room_configs) + + +@attr.s(auto_attribs=True) +class MutablePerConnectionState(PerConnectionState): + """A mutable version of `PerConnectionState`""" + + rooms: MutableRoomStatusMap[RoomStreamToken] + receipts: MutableRoomStatusMap[MultiWriterStreamToken] + account_data: MutableRoomStatusMap[int] + + room_configs: typing.ChainMap[str, RoomSyncConfig] + + def has_updates(self) -> bool: + return ( + bool(self.rooms.get_updates()) + or bool(self.receipts.get_updates()) + or bool(self.account_data.get_updates()) + or bool(self.get_room_config_updates()) + ) + + def get_room_config_updates(self) -> Mapping[str, RoomSyncConfig]: + """Get updates to 
the room sync config""" + return self.room_configs.maps[0] diff --git a/synapse/types/rest/__init__.py b/synapse/types/rest/__init__.py
index 2b6f5ed35a..183831e79a 100644 --- a/synapse/types/rest/__init__.py +++ b/synapse/types/rest/__init__.py
@@ -18,14 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import TYPE_CHECKING - -from synapse._pydantic_compat import HAS_PYDANTIC_V2 - -if TYPE_CHECKING or HAS_PYDANTIC_V2: - from pydantic.v1 import BaseModel, Extra -else: - from pydantic import BaseModel, Extra +from synapse._pydantic_compat import BaseModel, Extra class RequestBodyModel(BaseModel): diff --git a/synapse/types/rest/client/__init__.py b/synapse/types/rest/client/__init__.py
index 93b537ab7b..2d386bfe53 100644 --- a/synapse/types/rest/client/__init__.py +++ b/synapse/types/rest/client/__init__.py
@@ -20,31 +20,16 @@ # from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union -from synapse._pydantic_compat import HAS_PYDANTIC_V2 - -if TYPE_CHECKING or HAS_PYDANTIC_V2: - from pydantic.v1 import ( - Extra, - StrictBool, - StrictInt, - StrictStr, - conint, - constr, - validator, - ) -else: - from pydantic import ( - Extra, - StrictBool, - StrictInt, - StrictStr, - conint, - constr, - validator, - ) - +from synapse._pydantic_compat import ( + Extra, + StrictBool, + StrictInt, + StrictStr, + conint, + constr, + validator, +) from synapse.types.rest import RequestBodyModel -from synapse.util.threepids import validate_email class AuthenticationData(RequestBodyModel): @@ -76,33 +61,6 @@ else: ) -class ThreepidRequestTokenBody(RequestBodyModel): - client_secret: ClientSecretStr - id_server: Optional[StrictStr] - id_access_token: Optional[StrictStr] - next_link: Optional[StrictStr] - send_attempt: StrictInt - - @validator("id_access_token", always=True) - def token_required_for_identity_server( - cls, token: Optional[str], values: Dict[str, object] - ) -> Optional[str]: - if values.get("id_server") is not None and token is None: - raise ValueError("id_access_token is required if an id_server is supplied.") - return token - - -class EmailRequestTokenBody(ThreepidRequestTokenBody): - email: StrictStr - - # Canonicalise the email address. The addresses are all stored canonicalised - # in the database. This allows the user to reset his password without having to - # know the exact spelling (eg. upper and lower case) of address in the database. - # Without this, an email stored in the database as "foo@bar.com" would cause - # user requests for "FOO@bar.com" to raise a Not Found error. 
- _email_validator = validator("email", allow_reuse=True)(validate_email) - - if TYPE_CHECKING: ISO3116_1_Alpha_2 = StrictStr else: @@ -110,11 +68,6 @@ else: ISO3116_1_Alpha_2 = constr(regex="[A-Z]{2}", strict=True) -class MsisdnRequestTokenBody(ThreepidRequestTokenBody): - country: ISO3116_1_Alpha_2 - phone_number: StrictStr - - class SlidingSyncBody(RequestBodyModel): """ Sliding Sync API request body. @@ -268,7 +221,9 @@ class SlidingSyncBody(RequestBodyModel): if TYPE_CHECKING: ranges: Optional[List[Tuple[int, int]]] = None else: - ranges: Optional[List[Tuple[conint(ge=0, strict=True), conint(ge=0, strict=True)]]] = None # type: ignore[valid-type] + ranges: Optional[ + List[Tuple[conint(ge=0, strict=True), conint(ge=0, strict=True)]] + ] = None # type: ignore[valid-type] slow_get_all_rooms: Optional[StrictBool] = False filters: Optional[Filters] = None @@ -382,13 +337,15 @@ class SlidingSyncBody(RequestBodyModel): receipts: Optional[ReceiptsExtension] = None typing: Optional[TypingExtension] = None - conn_id: Optional[str] + conn_id: Optional[StrictStr] # mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884 if TYPE_CHECKING: lists: Optional[Dict[str, SlidingSyncList]] = None else: - lists: Optional[Dict[constr(max_length=64, strict=True), SlidingSyncList]] = None # type: ignore[valid-type] + lists: Optional[Dict[constr(max_length=64, strict=True), SlidingSyncList]] = ( + None # type: ignore[valid-type] + ) room_subscriptions: Optional[Dict[StrictStr, RoomSubscription]] = None extensions: Optional[Extensions] = None diff --git a/synapse/types/state.py b/synapse/types/state.py
index c958a95701..6420e050a5 100644 --- a/synapse/types/state.py +++ b/synapse/types/state.py
@@ -68,15 +68,23 @@ class StateFilter: include_others: bool = False def __attrs_post_init__(self) -> None: - # If `include_others` is set we canonicalise the filter by removing - # wildcards from the types dictionary if self.include_others: + # If `include_others` is set we canonicalise the filter by removing + # wildcards from the types dictionary + # this is needed to work around the fact that StateFilter is frozen object.__setattr__( self, "types", immutabledict({k: v for k, v in self.types.items() if v is not None}), ) + else: + # Otherwise we remove entries where the value is the empty set. + object.__setattr__( + self, + "types", + immutabledict({k: v for k, v in self.types.items() if v is None or v}), + ) @staticmethod def all() -> "StateFilter": @@ -454,7 +462,7 @@ class StateFilter: new_types.update({state_type: set() for state_type in minus_wildcards}) # insert the plus wildcards - new_types.update({state_type: None for state_type in plus_wildcards}) + new_types.update(dict.fromkeys(plus_wildcards)) # insert the specific state keys for state_type, state_key in plus_state_keys: @@ -503,13 +511,19 @@ class StateFilter: # - if so, which event types are excluded? 
('excludes') # - which entire event types to include ('wildcards') # - which concrete state keys to include ('concrete state keys') - (self_all, self_excludes), ( - self_wildcards, - self_concrete_keys, + ( + (self_all, self_excludes), + ( + self_wildcards, + self_concrete_keys, + ), ) = self._decompose_into_four_parts() - (other_all, other_excludes), ( - other_wildcards, - other_concrete_keys, + ( + (other_all, other_excludes), + ( + other_wildcards, + other_concrete_keys, + ), ) = other._decompose_into_four_parts() # Start with an estimate of the difference based on self @@ -610,6 +624,13 @@ class StateFilter: return False + def __bool__(self) -> bool: + """Returns true if this state filter will match any state, or false if + this is the empty filter""" + if self.include_others: + return True + return bool(self.types) + _ALL_STATE_FILTER = StateFilter(types=immutabledict(), include_others=True) _ALL_NON_MEMBER_STATE_FILTER = StateFilter( diff --git a/synapse/types/storage/__init__.py b/synapse/types/storage/__init__.py new file mode 100644
index 0000000000..378a15e038 --- /dev/null +++ b/synapse/types/storage/__init__.py
@@ -0,0 +1,56 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2024 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# +# Originally licensed under the Apache License, Version 2.0: +# <http://www.apache.org/licenses/LICENSE-2.0>. +# +# [This file includes modifications made by New Vector Limited] +# +# + + +class _BackgroundUpdates: + EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" + EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" + DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" + POPULATE_STREAM_ORDERING2 = "populate_stream_ordering2" + INDEX_STREAM_ORDERING2 = "index_stream_ordering2" + INDEX_STREAM_ORDERING2_CONTAINS_URL = "index_stream_ordering2_contains_url" + INDEX_STREAM_ORDERING2_ROOM_ORDER = "index_stream_ordering2_room_order" + INDEX_STREAM_ORDERING2_ROOM_STREAM = "index_stream_ordering2_room_stream" + INDEX_STREAM_ORDERING2_TS = "index_stream_ordering2_ts" + REPLACE_STREAM_ORDERING_COLUMN = "replace_stream_ordering_column" + + EVENT_EDGES_DROP_INVALID_ROWS = "event_edges_drop_invalid_rows" + EVENT_EDGES_REPLACE_INDEX = "event_edges_replace_index" + + EVENTS_POPULATE_STATE_KEY_REJECTIONS = "events_populate_state_key_rejections" + + EVENTS_JUMP_TO_DATE_INDEX = "events_jump_to_date_index" + + SLIDING_SYNC_PREFILL_JOINED_ROOMS_TO_RECALCULATE_TABLE_BG_UPDATE = ( + "sliding_sync_prefill_joined_rooms_to_recalculate_table_bg_update" + ) + SLIDING_SYNC_JOINED_ROOMS_BG_UPDATE = "sliding_sync_joined_rooms_bg_update" + SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BG_UPDATE = ( + "sliding_sync_membership_snapshots_bg_update" + ) + 
SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_FIX_FORGOTTEN_COLUMN_BG_UPDATE = ( + "sliding_sync_membership_snapshots_fix_forgotten_column_bg_update" + ) + + MARK_UNREFERENCED_STATE_GROUPS_FOR_DELETION_BG_UPDATE = ( + "mark_unreferenced_state_groups_for_deletion_bg_update" + ) + + FIXUP_MAX_DEPTH_CAP = "fixup_max_depth_cap" diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 70139beef2..e596e1ed20 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py
@@ -41,6 +41,7 @@ from typing import ( Hashable, Iterable, List, + Literal, Optional, Set, Tuple, @@ -51,7 +52,7 @@ from typing import ( ) import attr -from typing_extensions import Concatenate, Literal, ParamSpec +from typing_extensions import Concatenate, ParamSpec, Unpack from twisted.internet import defer from twisted.internet.defer import CancelledError @@ -61,6 +62,7 @@ from twisted.python.failure import Failure from synapse.logging.context import ( PreserveLoggingContext, make_deferred_yieldable, + run_coroutine_in_background, run_in_background, ) from synapse.util import Clock @@ -344,6 +346,7 @@ T1 = TypeVar("T1") T2 = TypeVar("T2") T3 = TypeVar("T3") T4 = TypeVar("T4") +T5 = TypeVar("T5") @overload @@ -402,6 +405,112 @@ def gather_results( # type: ignore[misc] return deferred.addCallback(tuple) +@overload +async def gather_optional_coroutines( + *coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]]]], +) -> Tuple[Optional[T1]]: ... + + +@overload +async def gather_optional_coroutines( + *coroutines: Unpack[ + Tuple[ + Optional[Coroutine[Any, Any, T1]], + Optional[Coroutine[Any, Any, T2]], + ] + ], +) -> Tuple[Optional[T1], Optional[T2]]: ... + + +@overload +async def gather_optional_coroutines( + *coroutines: Unpack[ + Tuple[ + Optional[Coroutine[Any, Any, T1]], + Optional[Coroutine[Any, Any, T2]], + Optional[Coroutine[Any, Any, T3]], + ] + ], +) -> Tuple[Optional[T1], Optional[T2], Optional[T3]]: ... + + +@overload +async def gather_optional_coroutines( + *coroutines: Unpack[ + Tuple[ + Optional[Coroutine[Any, Any, T1]], + Optional[Coroutine[Any, Any, T2]], + Optional[Coroutine[Any, Any, T3]], + Optional[Coroutine[Any, Any, T4]], + ] + ], +) -> Tuple[Optional[T1], Optional[T2], Optional[T3], Optional[T4]]: ... 
+ + +@overload +async def gather_optional_coroutines( + *coroutines: Unpack[ + Tuple[ + Optional[Coroutine[Any, Any, T1]], + Optional[Coroutine[Any, Any, T2]], + Optional[Coroutine[Any, Any, T3]], + Optional[Coroutine[Any, Any, T4]], + Optional[Coroutine[Any, Any, T5]], + ] + ], +) -> Tuple[Optional[T1], Optional[T2], Optional[T3], Optional[T4], Optional[T5]]: ... + + +async def gather_optional_coroutines( + *coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]], ...]], +) -> Tuple[Optional[T1], ...]: + """Helper function that allows waiting on multiple coroutines at once. + + The return value is a tuple of the return values of the coroutines in order. + + If a `None` is passed instead of a coroutine, it will be ignored and a None + is returned in the tuple. + + Note: For typechecking we need to have an explicit overload for each + distinct number of coroutines passed in. If you see type problems, it's + likely because you're using many arguments and you need to add a new + overload above. + """ + + try: + results = await make_deferred_yieldable( + defer.gatherResults( + [ + run_coroutine_in_background(coroutine) + for coroutine in coroutines + if coroutine is not None + ], + consumeErrors=True, + ) + ) + + results_iter = iter(results) + return tuple( + next(results_iter) if coroutine is not None else None + for coroutine in coroutines + ) + except defer.FirstError as dfe: + # unwrap the error from defer.gatherResults. + + # The raised exception's traceback only includes func() etc if + # the 'await' happens before the exception is thrown - ie if the failure + # happens *asynchronously* - otherwise Twisted throws away the traceback as it + # could be large. + # + # We could maybe reconstruct a fake traceback from Failure.frames. Or maybe + # we could throw Twisted into the fires of Mordor. + + # suppress exception chaining, because the FirstError doesn't tell us anything + # very interesting. 
+ assert isinstance(dfe.subFailure.value, BaseException) + raise dfe.subFailure.value from None + + @attr.s(slots=True, auto_attribs=True) class _LinearizerEntry: # The number of things executing. @@ -885,3 +994,46 @@ class AwakenableSleeper: # Cancel the sleep if we were woken up if call.active(): call.cancel() + + +class DeferredEvent: + """Like threading.Event but for async code""" + + def __init__(self, reactor: IReactorTime) -> None: + self._reactor = reactor + self._deferred: "defer.Deferred[None]" = defer.Deferred() + + def set(self) -> None: + if not self._deferred.called: + self._deferred.callback(None) + + def clear(self) -> None: + if self._deferred.called: + self._deferred = defer.Deferred() + + def is_set(self) -> bool: + return self._deferred.called + + async def wait(self, timeout_seconds: float) -> bool: + if self.is_set(): + return True + + # Create a deferred that gets called in N seconds + sleep_deferred: "defer.Deferred[None]" = defer.Deferred() + call = self._reactor.callLater(timeout_seconds, sleep_deferred.callback, None) + + try: + await make_deferred_yieldable( + defer.DeferredList( + [sleep_deferred, self._deferred], + fireOnOneCallback=True, + fireOnOneErrback=True, + consumeErrors=True, + ) + ) + finally: + # Cancel the sleep if we were woken up + if call.active(): + call.cancel() + + return self.is_set() diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 1e6696332f..14bd3ba3b0 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py
@@ -21,10 +21,19 @@ import enum import logging import threading -from typing import Dict, Generic, Iterable, Optional, Set, Tuple, TypeVar, Union +from typing import ( + Dict, + Generic, + Iterable, + Literal, + Optional, + Set, + Tuple, + TypeVar, + Union, +) import attr -from typing_extensions import Literal from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 8017c031ee..3198fdd2ed 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py
@@ -21,10 +21,9 @@ import logging from collections import OrderedDict -from typing import Any, Generic, Iterable, Optional, TypeVar, Union, overload +from typing import Any, Generic, Iterable, Literal, Optional, TypeVar, Union, overload import attr -from typing_extensions import Literal from twisted.internet import defer diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 481a1a621e..2e5efa3a52 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py
@@ -34,6 +34,7 @@ from typing import ( Generic, Iterable, List, + Literal, Optional, Set, Tuple, @@ -44,8 +45,6 @@ from typing import ( overload, ) -from typing_extensions import Literal - from twisted.internet import reactor from twisted.internet.interfaces import IReactorTime diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 96b7ca83dc..54b99134b9 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py
@@ -101,7 +101,13 @@ class ResponseCache(Generic[KV]): used rather than trying to compute a new response. """ - def __init__(self, clock: Clock, name: str, timeout_ms: float = 0): + def __init__( + self, + clock: Clock, + name: str, + timeout_ms: float = 0, + enable_logging: bool = True, + ): self._result_cache: Dict[KV, ResponseCacheEntry] = {} self.clock = clock @@ -109,6 +115,7 @@ class ResponseCache(Generic[KV]): self._name = name self._metrics = register_cache("response_cache", name, self, resizable=False) + self._enable_logging = enable_logging def size(self) -> int: return len(self._result_cache) @@ -246,9 +253,12 @@ class ResponseCache(Generic[KV]): """ entry = self._get(key) if not entry: - logger.debug( - "[%s]: no cached result for [%s], calculating new one", self._name, key - ) + if self._enable_logging: + logger.debug( + "[%s]: no cached result for [%s], calculating new one", + self._name, + key, + ) context = ResponseCacheContext(cache_key=key) if cache_context: kwargs["cache_context"] = context @@ -269,12 +279,15 @@ class ResponseCache(Generic[KV]): return await make_deferred_yieldable(entry.result.observe()) result = entry.result.observe() - if result.called: - logger.info("[%s]: using completed cached result for [%s]", self._name, key) - else: - logger.info( - "[%s]: using incomplete cached result for [%s]", self._name, key - ) + if self._enable_logging: + if result.called: + logger.info( + "[%s]: using completed cached result for [%s]", self._name, key + ) + else: + logger.info( + "[%s]: using incomplete cached result for [%s]", self._name, key + ) span_context = entry.opentracing_span_context with start_active_span_follows_from( diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 16fcb00206..5ac8643eef 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py
@@ -142,9 +142,9 @@ class StreamChangeCache: """ assert isinstance(stream_pos, int) - # _cache is not valid at or before the earliest known stream position, so + # _cache is not valid before the earliest known stream position, so # return that the entity has changed. - if stream_pos <= self._earliest_known_stream_pos: + if stream_pos < self._earliest_known_stream_pos: self.metrics.inc_misses() return True @@ -186,7 +186,7 @@ class StreamChangeCache: This will be all entities if the given stream position is at or earlier than the earliest known stream position. """ - if not self._cache or stream_pos <= self._earliest_known_stream_pos: + if not self._cache or stream_pos < self._earliest_known_stream_pos: self.metrics.inc_misses() return set(entities) @@ -238,9 +238,9 @@ class StreamChangeCache: """ assert isinstance(stream_pos, int) - # _cache is not valid at or before the earliest known stream position, so + # _cache is not valid before the earliest known stream position, so # return that an entity has changed. - if stream_pos <= self._earliest_known_stream_pos: + if stream_pos < self._earliest_known_stream_pos: self.metrics.inc_misses() return True @@ -270,9 +270,9 @@ class StreamChangeCache: """ assert isinstance(stream_pos, int) - # _cache is not valid at or before the earliest known stream position, so + # _cache is not valid before the earliest known stream position, so # return None to mark that it is unknown if an entity has changed. - if stream_pos <= self._earliest_known_stream_pos: + if stream_pos < self._earliest_known_stream_pos: return AllEntitiesChangedResult(None) changed_entities: List[EntityType] = [] @@ -314,6 +314,15 @@ class StreamChangeCache: self._entity_to_key[entity] = stream_pos self._evict() + def all_entities_changed(self, stream_pos: int) -> None: + """ + Mark all entities as changed. This is useful when the cache is invalidated and + there may be some potential change for all of the entities. 
+ """ + self._cache.clear() + self._entity_to_key.clear() + self._earliest_known_stream_pos = stream_pos + def _evict(self) -> None: """ Ensure the cache has not exceeded the maximum size. diff --git a/synapse/util/events.py b/synapse/util/events.py new file mode 100644
index 0000000000..ad9b946578 --- /dev/null +++ b/synapse/util/events.py
@@ -0,0 +1,29 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2024 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# +# + +from synapse.util.stringutils import random_string + + +def generate_fake_event_id() -> str: + """ + Generate an event ID from random ASCII characters. + + This is primarily useful for generating fake event IDs in response to + requests from shadow-banned users. + + Returns: + A string intended to look like an event ID, but with no actual meaning. + """ + return "$" + random_string(43) diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
index b73f690b88..0a6a30aab2 100644 --- a/synapse/util/iterutils.py +++ b/synapse/util/iterutils.py
@@ -30,14 +30,13 @@ from typing import ( Iterator, List, Mapping, + Protocol, Set, Sized, Tuple, TypeVar, ) -from typing_extensions import Protocol - T = TypeVar("T") S = TypeVar("S", bound="_SelfSlice") @@ -115,7 +114,7 @@ def sorted_topologically( # This is implemented by Kahn's algorithm. - degree_map = {node: 0 for node in nodes} + degree_map = dict.fromkeys(nodes, 0) reverse_graph: Dict[T, Set[T]] = {} for node, edges in graph.items(): @@ -165,7 +164,7 @@ def sorted_topologically_batched( persisted. """ - degree_map = {node: 0 for node in nodes} + degree_map = dict.fromkeys(nodes, 0) reverse_graph: Dict[T, Set[T]] = {} for node, edges in graph.items(): diff --git a/synapse/util/linked_list.py b/synapse/util/linked_list.py
index e9a5fff211..87f801c0cf 100644 --- a/synapse/util/linked_list.py +++ b/synapse/util/linked_list.py
@@ -19,8 +19,7 @@ # # -"""A circular doubly linked list implementation. -""" +"""A circular doubly linked list implementation.""" import threading from typing import Generic, Optional, Type, TypeVar diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py
index 84ae226207..6fa15543ec 100644 --- a/synapse/util/macaroons.py +++ b/synapse/util/macaroons.py
@@ -22,12 +22,11 @@ """Utilities for manipulating macaroons""" -from typing import Callable, Optional +from typing import Callable, Literal, Optional import attr import pymacaroons from pymacaroons.exceptions import MacaroonVerificationFailedException -from typing_extensions import Literal from synapse.util import Clock, stringutils diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 517e79ce5f..6a389f7a7e 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py
@@ -22,10 +22,19 @@ import logging from functools import wraps from types import TracebackType -from typing import Awaitable, Callable, Dict, Generator, Optional, Type, TypeVar +from typing import ( + Awaitable, + Callable, + Dict, + Generator, + Optional, + Protocol, + Type, + TypeVar, +) from prometheus_client import CollectorRegistry, Counter, Metric -from typing_extensions import Concatenate, ParamSpec, Protocol +from typing_extensions import Concatenate, ParamSpec from synapse.logging.context import ( ContextResourceUsage, @@ -110,7 +119,7 @@ def measure_func( """ def wrapper( - func: Callable[Concatenate[HasClock, P], Awaitable[R]] + func: Callable[Concatenate[HasClock, P], Awaitable[R]], ) -> Callable[P, Awaitable[R]]: block_name = func.__name__ if name is None else name diff --git a/synapse/util/msisdn.py b/synapse/util/msisdn.py deleted file mode 100644
index b6a784f0bc..0000000000 --- a/synapse/util/msisdn.py +++ /dev/null
@@ -1,51 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright 2017 Vector Creations Ltd -# Copyright (C) 2023 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -# Originally licensed under the Apache License, Version 2.0: -# <http://www.apache.org/licenses/LICENSE-2.0>. -# -# [This file includes modifications made by New Vector Limited] -# -# - -import phonenumbers - -from synapse.api.errors import SynapseError - - -def phone_number_to_msisdn(country: str, number: str) -> str: - """ - Takes an ISO-3166-1 2 letter country code and phone number and - returns an msisdn representing the canonical version of that - phone number. - - As an example, if `country` is "GB" and `number` is "7470674927", this - function will return "447470674927". - - Args: - country: ISO-3166-1 2 letter country code - number: Phone number in a national or international format - - Returns: - The canonical form of the phone number, as an msisdn. - Raises: - SynapseError if the number could not be parsed. - """ - try: - phoneNumber = phonenumbers.parse(number, country) - except phonenumbers.NumberParseException: - raise SynapseError(400, "Unable to parse phone number") - return phonenumbers.format_number(phoneNumber, phonenumbers.PhoneNumberFormat.E164)[ - 1: - ] diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index 46dad32156..beea4d2888 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py
@@ -50,7 +50,7 @@ def do_patch() -> None: return def new_inline_callbacks( - f: Callable[P, Generator["Deferred[object]", object, T]] + f: Callable[P, Generator["Deferred[object]", object, T]], ) -> Callable[P, "Deferred[T]"]: @functools.wraps(f) def wrapped(*args: P.args, **kwargs: P.kwargs) -> "Deferred[T]": @@ -162,7 +162,7 @@ def _check_yield_points( d = result.throwExceptionIntoGenerator(gen) else: d = gen.send(result) - except (StopIteration, defer._DefGen_Return) as e: + except StopIteration as e: if current_context() != expected_context: # This happens when the context is lost sometime *after* the # final yield and returning. E.g. we forgot to yield on a @@ -183,7 +183,7 @@ def _check_yield_points( ) ) changes.append(err) - # The `StopIteration` or `_DefGen_Return` contains the return value from the + # The `StopIteration` contains the return value from the # generator. return cast(T, e.value) diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 8ead72bb7a..3f067b792c 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py
@@ -103,7 +103,7 @@ _rate_limiter_instances_lock = threading.Lock() def _get_counts_from_rate_limiter_instance( - count_func: Callable[["FederationRateLimiter"], int] + count_func: Callable[["FederationRateLimiter"], int], ) -> Mapping[Tuple[str, ...], int]: """Returns a count of something (slept/rejected hosts) by (metrics_name)""" # Cast to a list to prevent it changing while the Prometheus diff --git a/synapse/util/rust.py b/synapse/util/rust.py
index 0e35d6d188..37f43459f1 100644 --- a/synapse/util/rust.py +++ b/synapse/util/rust.py
@@ -19,9 +19,12 @@ # # +import json import os -import sys +import urllib.parse from hashlib import blake2b +from importlib.metadata import Distribution, PackageNotFoundError +from typing import Optional import synapse from synapse.synapse_rust import get_rust_file_digest @@ -32,22 +35,17 @@ def check_rust_lib_up_to_date() -> None: be rebuilt. """ - if not _dist_is_editable(): - return - - synapse_dir = os.path.dirname(synapse.__file__) - synapse_root = os.path.abspath(os.path.join(synapse_dir, "..")) - - # Double check we've not gone into site-packages... - if os.path.basename(synapse_root) == "site-packages": - return - - # ... and it looks like the root of a python project. - if not os.path.exists("pyproject.toml"): - return + # Get the location of the editable install. + synapse_root = get_synapse_source_directory() + if synapse_root is None: + return None # Get the hash of all Rust source files - hash = _hash_rust_files_in_directory(os.path.join(synapse_root, "rust", "src")) + rust_path = os.path.join(synapse_root, "rust", "src") + if not os.path.exists(rust_path): + return None + + hash = _hash_rust_files_in_directory(rust_path) if hash != get_rust_file_digest(): raise Exception("Rust module outdated. Please rebuild using `poetry install`") @@ -82,10 +80,55 @@ def _hash_rust_files_in_directory(directory: str) -> str: return hasher.hexdigest() -def _dist_is_editable() -> bool: - """Is distribution an editable install?""" - for path_item in sys.path: - egg_link = os.path.join(path_item, "matrix-synapse.egg-link") - if os.path.isfile(egg_link): - return True - return False +def get_synapse_source_directory() -> Optional[str]: + """Try and find the source directory of synapse for editable installs (like + those used in development). + + Returns None if not an editable install (or otherwise can't find the source + directory). + """ + + # Try and find the installed matrix-synapse package. 
+ try: + package = Distribution.from_name("matrix-synapse") + except PackageNotFoundError: + # The package is not found, so it's not installed and so must be being + # pulled out from a local directory (usually the current one). + synapse_dir = os.path.dirname(synapse.__file__) + synapse_root = os.path.abspath(os.path.join(synapse_dir, "..")) + + # Double check we've not gone into site-packages... + if os.path.basename(synapse_root) == "site-packages": + return None + + # ... and it looks like the root of a python project. + if not os.path.exists("pyproject.toml"): + return None + + return synapse_root + + # Read the `direct_url.json` metadata for the package. This won't exist for + # packages installed via a repository/etc. + # c.f. https://packaging.python.org/en/latest/specifications/direct-url/ + direct_url_json = package.read_text("direct_url.json") + if direct_url_json is None: + return None + + # c.f. https://packaging.python.org/en/latest/specifications/direct-url/ for + # the format + direct_url_dict: dict = json.loads(direct_url_json) + + # `url` must exist as a key, and point to where we fetched the repo from. + project_url = urllib.parse.urlparse(direct_url_dict["url"]) + + # If its not a local file then we must have built the rust libs either a) + # after we downloaded the package, or b) we built the download wheel. + if project_url.scheme != "file": + return None + + # And finally if its not an editable install then the files can't have + # changed since we installed the package. + if not direct_url_dict.get("dir_info", {}).get("editable", False): + return None + + return project_url.path diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 13ff54b669..32b5bc00c9 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py
@@ -43,6 +43,14 @@ CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$") # MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$") +# https://spec.matrix.org/v1.13/appendices/#common-namespaced-identifier-grammar +# +# At least one character, less than or equal to 255 characters. Must start with +# a-z, the rest is a-z, 0-9, -, _, or .. +# +# This doesn't check anything about validity of namespaces. +NAMESPACED_GRAMMAR = re.compile(r"^[a-z][a-z0-9_.-]{0,254}$") + def random_string(length: int) -> str: """Generate a cryptographically secure string of random letters. @@ -68,6 +76,10 @@ def is_ascii(s: bytes) -> bool: return True +def is_namedspaced_grammar(s: str) -> bool: + return bool(NAMESPACED_GRAMMAR.match(s)) + + def assert_valid_client_secret(client_secret: str) -> None: """Validate that a given string matches the client_secret defined by the spec""" if ( diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py
index 448960b297..4683d09cd7 100644 --- a/synapse/util/task_scheduler.py +++ b/synapse/util/task_scheduler.py
@@ -46,33 +46,43 @@ logger = logging.getLogger(__name__) class TaskScheduler: """ - This is a simple task sheduler aimed at resumable tasks: usually we use `run_in_background` - to launch a background task, or Twisted `deferLater` if we want to do so later on. - - The problem with that is that the tasks will just stop and never be resumed if synapse - is stopped for whatever reason. - - How this works: - - A function mapped to a named action should first be registered with `register_action`. - This function will be called when trying to resuming tasks after a synapse shutdown, - so this registration should happen when synapse is initialised, NOT right before scheduling - a task. - - A task can then be launched using this named action with `schedule_task`. A `params` dict - can be passed, and it will be available to the registered function when launched. This task - can be launch either now-ish, or later on by giving a `timestamp` parameter. - - The function may call `update_task` at any time to update the `result` of the task, - and this can be used to resume the task at a specific point and/or to convey a result to - the code launching the task. - You can also specify the `result` (and/or an `error`) when returning from the function. - - The reconciliation loop runs every minute, so this is not a precise scheduler. - There is a limit of 10 concurrent tasks, so tasks may be delayed if the pool is already - full. In this regard, please take great care that scheduled tasks can actually finished. - For now there is no mechanism to stop a running task if it is stuck. - - Tasks will be run on the worker specified with `run_background_tasks_on` config, - or the main one by default. + This is a simple task scheduler designed for resumable tasks. Normally, + you'd use `run_in_background` to start a background task or Twisted's + `deferLater` if you want to run it later. 
+ + The issue is that these tasks stop completely and won't resume if Synapse is + shut down for any reason. + + Here's how it works: + + - Register an Action: First, you need to register a function to a named + action using `register_action`. This function will be called to resume tasks + after a Synapse shutdown. Make sure to register it when Synapse initializes, + not right before scheduling the task. + + - Schedule a Task: You can launch a task linked to the named action + using `schedule_task`. You can pass a `params` dictionary, which will be + passed to the registered function when it's executed. Tasks can be scheduled + to run either immediately or later by specifying a `timestamp`. + + - Update Task: The function handling the task can call `update_task` at + any point to update the task's `result`. This lets you resume the task from + a specific point or pass results back to the code that scheduled it. When + the function completes, you can also return a `result` or an `error`. + + Things to keep in mind: + + - The reconciliation loop runs every minute, so this is not a high-precision + scheduler. + + - Only 10 tasks can run at the same time. If the pool is full, tasks may be + delayed. Make sure your scheduled tasks can actually finish. + + - Currently, there's no way to stop a task if it gets stuck. + + - Tasks will run on the worker defined by the `run_background_tasks_on` + setting in your configuration. If no worker is specified, they'll run on + the main one by default. """ # Precision of the scheduler, evaluation of tasks to run will only happen @@ -157,7 +167,7 @@ class TaskScheduler: params: Optional[JsonMapping] = None, ) -> str: """Schedule a new potentially resumable task. A function matching the specified - `action` should have be registered with `register_action` before the task is run. + `action` should've been registered with `register_action` before the task is run. 
Args: action: the name of a previously registered action @@ -174,9 +184,10 @@ class TaskScheduler: The id of the scheduled task """ status = TaskStatus.SCHEDULED + start_now = False if timestamp is None or timestamp < self._clock.time_msec(): timestamp = self._clock.time_msec() - status = TaskStatus.ACTIVE + start_now = True task = ScheduledTask( random_string(16), @@ -190,9 +201,11 @@ class TaskScheduler: ) await self._store.insert_scheduled_task(task) - if status == TaskStatus.ACTIVE: + # If the task is ready to run immediately, run the scheduling algorithm now + # rather than waiting + if start_now: if self._run_background_tasks: - await self._launch_task(task) + self._launch_scheduled_tasks() else: self._hs.get_replication_command_handler().send_new_active_task(task.id) @@ -207,15 +220,15 @@ class TaskScheduler: result: Optional[JsonMapping] = None, error: Optional[str] = None, ) -> bool: - """Update some task associated values. This is exposed publicly so it can - be used inside task functions, mainly to update the result and be able to - resume a task at a specific step after a restart of synapse. + """Update some task-associated values. This is exposed publicly so it can + be used inside task functions, mainly to update the result or resume + a task at a specific step after a restart of synapse. It can also be used to stage a task, by setting the `status` to `SCHEDULED` with a new timestamp. - The `status` can only be set to `ACTIVE` or `SCHEDULED`, `COMPLETE` and `FAILED` - are terminal status and can only be set by returning it in the function. + The `status` can only be set to `ACTIVE` or `SCHEDULED`. `COMPLETE` and `FAILED` + are terminal statuses and can only be set by returning them from the function. 
Args: id: the id of the task to update @@ -223,6 +236,12 @@ class TaskScheduler: status: the new `TaskStatus` of the task result: the new result of the task error: the new error of the task + + Returns: + True if the update was successful, False otherwise. + + Raises: + Exception: If a status other than `ACTIVE` or `SCHEDULED` was passed. """ if status == TaskStatus.COMPLETE or status == TaskStatus.FAILED: raise Exception( @@ -260,9 +279,9 @@ class TaskScheduler: max_timestamp: Optional[int] = None, limit: Optional[int] = None, ) -> List[ScheduledTask]: - """Get a list of tasks. Returns all the tasks if no args is provided. + """Get a list of tasks. Returns all the tasks if no args are provided. - If an arg is `None` all tasks matching the other args will be selected. + If an arg is `None`, all tasks matching the other args will be selected. If an arg is an empty list, the corresponding value of the task needs to be `None` to be selected. @@ -274,8 +293,8 @@ class TaskScheduler: a timestamp inferior to the specified one limit: Only return `limit` number of rows if set. - Returns - A list of `ScheduledTask`, ordered by increasing timestamps + Returns: + A list of `ScheduledTask`, ordered by increasing timestamps. """ return await self._store.get_scheduled_tasks( actions=actions, @@ -300,23 +319,13 @@ class TaskScheduler: raise Exception(f"Task {id} is currently ACTIVE and can't be deleted") await self._store.delete_scheduled_task(id) - def launch_task_by_id(self, id: str) -> None: - """Try launching the task with the given ID.""" - # Don't bother trying to launch new tasks if we're already at capacity. 
- if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS: - return - - run_as_background_process("launch_task_by_id", self._launch_task_by_id, id) + def on_new_task(self, task_id: str) -> None: + """Handle a notification that a new ready-to-run task has been added to the queue""" + # Just run the scheduler + self._launch_scheduled_tasks() - async def _launch_task_by_id(self, id: str) -> None: - """Helper async function for `launch_task_by_id`.""" - task = await self.get_task(id) - if task: - await self._launch_task(task) - - @wrap_as_background_process("launch_scheduled_tasks") - async def _launch_scheduled_tasks(self) -> None: - """Retrieve and launch scheduled tasks that should be running at that time.""" + def _launch_scheduled_tasks(self) -> None: + """Retrieve and launch scheduled tasks that should be running at this time.""" # Don't bother trying to launch new tasks if we're already at capacity. if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS: return @@ -326,20 +335,26 @@ class TaskScheduler: self._launching_new_tasks = True - try: - for task in await self.get_tasks( - statuses=[TaskStatus.ACTIVE], limit=self.MAX_CONCURRENT_RUNNING_TASKS - ): - await self._launch_task(task) - for task in await self.get_tasks( - statuses=[TaskStatus.SCHEDULED], - max_timestamp=self._clock.time_msec(), - limit=self.MAX_CONCURRENT_RUNNING_TASKS, - ): - await self._launch_task(task) - - finally: - self._launching_new_tasks = False + async def inner() -> None: + try: + for task in await self.get_tasks( + statuses=[TaskStatus.ACTIVE], + limit=self.MAX_CONCURRENT_RUNNING_TASKS, + ): + # _launch_task will ignore tasks that we're already running, and + # will also do nothing if we're already at the maximum capacity. 
+ await self._launch_task(task) + for task in await self.get_tasks( + statuses=[TaskStatus.SCHEDULED], + max_timestamp=self._clock.time_msec(), + limit=self.MAX_CONCURRENT_RUNNING_TASKS, + ): + await self._launch_task(task) + + finally: + self._launching_new_tasks = False + + run_as_background_process("launch_scheduled_tasks", inner) @wrap_as_background_process("clean_scheduled_tasks") async def _clean_scheduled_tasks(self) -> None: diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py deleted file mode 100644
index 5c9193e8a9..0000000000 --- a/synapse/util/threepids.py +++ /dev/null
@@ -1,123 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2023 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -# Originally licensed under the Apache License, Version 2.0: -# <http://www.apache.org/licenses/LICENSE-2.0>. -# -# [This file includes modifications made by New Vector Limited] -# -# - -import logging -import re -import typing - -if typing.TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -# it's unclear what the maximum length of an email address is. RFC3696 (as corrected -# by errata) says: -# the upper limit on address lengths should normally be considered to be 254. -# -# In practice, mail servers appear to be more tolerant and allow 400 characters -# or so. Let's allow 500, which should be plenty for everyone. -# -MAX_EMAIL_ADDRESS_LENGTH = 500 - - -async def check_3pid_allowed( - hs: "HomeServer", - medium: str, - address: str, - registration: bool = False, -) -> bool: - """Checks whether a given format of 3PID is allowed to be used on this HS - - Args: - hs: server - medium: 3pid medium - e.g. email, msisdn - address: address within that medium (e.g. "wotan@matrix.org") - msisdns need to first have been canonicalised - registration: whether we want to bind the 3PID as part of registering a new user. 
- - Returns: - whether the 3PID medium/address is allowed to be added to this HS - """ - if not await hs.get_password_auth_provider().is_3pid_allowed( - medium, address, registration - ): - return False - - if hs.config.registration.allowed_local_3pids: - for constraint in hs.config.registration.allowed_local_3pids: - logger.debug( - "Checking 3PID %s (%s) against %s (%s)", - address, - medium, - constraint["pattern"], - constraint["medium"], - ) - if medium == constraint["medium"] and re.match( - constraint["pattern"], address - ): - return True - else: - return True - - return False - - -def canonicalise_email(address: str) -> str: - """'Canonicalise' email address - Case folding of local part of email address and lowercase domain part - See MSC2265, https://github.com/matrix-org/matrix-doc/pull/2265 - - Args: - address: email address to be canonicalised - Returns: - The canonical form of the email address - Raises: - ValueError if the address could not be parsed. - """ - - address = address.strip() - - parts = address.split("@") - if len(parts) != 2: - logger.debug("Couldn't parse email address %s", address) - raise ValueError("Unable to parse email address") - - return parts[0].casefold() + "@" + parts[1].lower() - - -def validate_email(address: str) -> str: - """Does some basic validation on an email address. - - Returns the canonicalised email, as returned by `canonicalise_email`. - - Raises a ValueError if the email is invalid. - """ - # First we try canonicalising in case that fails - address = canonicalise_email(address) - - # Email addresses have to be at least 3 characters. - if len(address) < 3: - raise ValueError("Unable to parse email address") - - if len(address) > MAX_EMAIL_ADDRESS_LENGTH: - raise ValueError("Unable to parse email address") - - return address diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
index 44b109bdfd..95eb1d7185 100644 --- a/synapse/util/wheel_timer.py +++ b/synapse/util/wheel_timer.py
@@ -47,7 +47,6 @@ class WheelTimer(Generic[T]): """ self.bucket_size: int = bucket_size self.entries: List[_Entry[T]] = [] - self.current_tick: int = 0 def insert(self, now: int, obj: T, then: int) -> None: """Inserts object into timer. @@ -78,11 +77,10 @@ class WheelTimer(Generic[T]): self.entries[max(min_key, then_key) - min_key].elements.add(obj) return - next_key = now_key + 1 if self.entries: - last_key = self.entries[-1].end_key + last_key = self.entries[-1].end_key + 1 else: - last_key = next_key + last_key = now_key + 1 # Handle the case when `then` is in the past and `entries` is empty. then_key = max(last_key, then_key) diff --git a/synapse/visibility.py b/synapse/visibility.py
index 128413c8aa..dc7b6e4065 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py
@@ -27,7 +27,6 @@ from typing import ( Final, FrozenSet, List, - Mapping, Optional, Sequence, Set, @@ -48,6 +47,7 @@ from synapse.events.utils import clone_event, prune_event from synapse.logging.opentracing import trace from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore +from synapse.synapse_rust.events import event_visible_to_server from synapse.types import RetentionPolicy, StateMap, StrCollection, get_domain_from_id from synapse.types.state import StateFilter from synapse.util import Clock @@ -135,9 +135,9 @@ async def filter_events_for_client( retention_policies: Dict[str, RetentionPolicy] = {} for room_id in room_ids: - retention_policies[room_id] = ( - await storage.main.get_retention_policy_for_room(room_id) - ) + retention_policies[ + room_id + ] = await storage.main.get_retention_policy_for_room(room_id) def allowed(event: EventBase) -> Optional[EventBase]: state_after_event = event_id_to_state.get(event.event_id) @@ -628,17 +628,6 @@ async def filter_events_for_server( """Filter a list of events based on whether the target server is allowed to see them. - For a fully stated room, the target server is allowed to see an event E if: - - the state at E has world readable or shared history vis, OR - - the state at E says that the target server is in the room. - - For a partially stated room, the target server is allowed to see E if: - - E was created by this homeserver, AND: - - the partial state at E has world readable or shared history vis, OR - - the partial state at E says that the target server is in the room. - - TODO: state before or state after? - Args: storage target_server_name @@ -655,35 +644,6 @@ async def filter_events_for_server( The filtered events. 
""" - def is_sender_erased(event: EventBase, erased_senders: Mapping[str, bool]) -> bool: - if erased_senders and erased_senders[event.sender]: - logger.info("Sender of %s has been erased, redacting", event.event_id) - return True - return False - - def check_event_is_visible( - visibility: str, memberships: StateMap[EventBase] - ) -> bool: - if visibility not in (HistoryVisibility.INVITED, HistoryVisibility.JOINED): - return True - - # We now loop through all membership events looking for - # membership states for the requesting server to determine - # if the server is either in the room or has been invited - # into the room. - for ev in memberships.values(): - assert get_domain_from_id(ev.state_key) == target_server_name - - memtype = ev.membership - if memtype == Membership.JOIN: - return True - elif memtype == Membership.INVITE: - if visibility == HistoryVisibility.INVITED: - return True - - # server has no users in the room: redact - return False - if filter_out_erased_senders: erased_senders = await storage.main.are_users_erased(e.sender for e in events) else: @@ -726,20 +686,16 @@ async def filter_events_for_server( target_server_name, ) - def include_event_in_output(e: EventBase) -> bool: - erased = is_sender_erased(e, erased_senders) - visible = check_event_is_visible( - event_to_history_vis[e.event_id], event_to_memberships.get(e.event_id, {}) - ) - - if e.event_id in partial_state_invisible_event_ids: - visible = False - - return visible and not erased - to_return = [] for e in events: - if include_event_in_output(e): + if event_visible_to_server( + sender=e.sender, + target_server_name=target_server_name, + history_visibility=event_to_history_vis[e.event_id], + erased_senders=erased_senders, + partial_state_invisible=e.event_id in partial_state_invisible_event_ids, + memberships=list(event_to_memberships.get(e.event_id, {}).values()), + ): to_return.append(e) elif redact: to_return.append(prune_event(e)) @@ -796,7 +752,7 @@ async def 
_event_to_history_vis( async def _event_to_memberships( storage: StorageControllers, events: Collection[EventBase], server_name: str -) -> Dict[str, StateMap[EventBase]]: +) -> Dict[str, StateMap[Tuple[str, str]]]: """Get the remote membership list at each of the given events Returns a map from event id to state map, which will contain only membership events @@ -849,7 +805,7 @@ async def _event_to_memberships( return { e_id: { - key: event_map[inner_e_id] + key: (event_map[inner_e_id].state_key, event_map[inner_e_id].membership) for key, inner_e_id in key_to_eid.items() if inner_e_id in event_map }