summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/__init__.py1
-rw-r--r--tests/api/test_auth.py5
-rw-r--r--tests/api/test_errors.py9
-rw-r--r--tests/api/test_filtering.py3
-rw-r--r--tests/api/test_ratelimiting.py5
-rw-r--r--tests/app/test_homeserver_start.py1
-rw-r--r--tests/appservice/__init__.py1
-rw-r--r--tests/appservice/test_api.py1
-rw-r--r--tests/appservice/test_appservice.py1
-rw-r--r--tests/appservice/test_scheduler.py1
-rw-r--r--tests/config/__init__.py1
-rw-r--r--tests/config/test___main__.py1
-rw-r--r--tests/config/test_appservice.py1
-rw-r--r--tests/config/test_background_update.py1
-rw-r--r--tests/config/test_base.py1
-rw-r--r--tests/config/test_cache.py1
-rw-r--r--tests/config/test_generate.py1
-rw-r--r--tests/config/test_load.py2
-rw-r--r--tests/config/test_oauth_delegation.py1
-rw-r--r--tests/config/test_ratelimiting.py1
-rw-r--r--tests/config/test_registration_config.py1
-rw-r--r--tests/config/test_room_directory.py37
-rw-r--r--tests/config/test_tls.py1
-rw-r--r--tests/config/test_util.py1
-rw-r--r--tests/config/test_workers.py1
-rw-r--r--tests/config/utils.py1
-rw-r--r--tests/crypto/__init__.py1
-rw-r--r--tests/crypto/test_event_signing.py1
-rw-r--r--tests/crypto/test_keyring.py1
-rw-r--r--tests/events/test_auto_accept_invites.py657
-rw-r--r--tests/events/test_presence_router.py18
-rw-r--r--tests/events/test_snapshot.py1
-rw-r--r--tests/events/test_utils.py28
-rw-r--r--tests/federation/test_complexity.py1
-rw-r--r--tests/federation/test_federation_client.py1
-rw-r--r--tests/federation/test_federation_media.py258
-rw-r--r--tests/federation/test_federation_sender.py119
-rw-r--r--tests/federation/test_federation_server.py18
-rw-r--r--tests/federation/transport/server/__init__.py1
-rw-r--r--tests/federation/transport/server/test__base.py8
-rw-r--r--tests/federation/transport/test_client.py1
-rw-r--r--tests/federation/transport/test_knocking.py1
-rw-r--r--tests/federation/transport/test_server.py10
-rw-r--r--tests/handlers/test_admin.py1
-rw-r--r--tests/handlers/test_appservice.py1
-rw-r--r--tests/handlers/test_auth.py1
-rw-r--r--tests/handlers/test_cas.py1
-rw-r--r--tests/handlers/test_deactivate_account.py147
-rw-r--r--tests/handlers/test_device.py2
-rw-r--r--tests/handlers/test_directory.py53
-rw-r--r--tests/handlers/test_e2e_keys.py111
-rw-r--r--tests/handlers/test_e2e_room_keys.py2
-rw-r--r--tests/handlers/test_federation.py2
-rw-r--r--tests/handlers/test_federation_event.py1
-rw-r--r--tests/handlers/test_message.py41
-rw-r--r--tests/handlers/test_oauth_delegation.py3
-rw-r--r--tests/handlers/test_oidc.py1
-rw-r--r--tests/handlers/test_password_providers.py1
-rw-r--r--tests/handlers/test_presence.py1
-rw-r--r--tests/handlers/test_profile.py1
-rw-r--r--tests/handlers/test_receipts.py1
-rw-r--r--tests/handlers/test_register.py1
-rw-r--r--tests/handlers/test_room_member.py25
-rw-r--r--tests/handlers/test_room_summary.py1
-rw-r--r--tests/handlers/test_saml.py1
-rw-r--r--tests/handlers/test_send_email.py1
-rw-r--r--tests/handlers/test_sliding_sync.py4567
-rw-r--r--tests/handlers/test_sync.py834
-rw-r--r--tests/handlers/test_typing.py54
-rw-r--r--tests/handlers/test_user_directory.py39
-rw-r--r--tests/handlers/test_worker_lock.py24
-rw-r--r--tests/http/federation/test_srv_resolver.py1
-rw-r--r--tests/http/server/__init__.py1
-rw-r--r--tests/http/server/_base.py1
-rw-r--r--tests/http/test_client.py144
-rw-r--r--tests/http/test_proxy.py1
-rw-r--r--tests/http/test_proxyagent.py1
-rw-r--r--tests/http/test_servlet.py1
-rw-r--r--tests/http/test_simple_client.py1
-rw-r--r--tests/http/test_site.py1
-rw-r--r--tests/logging/__init__.py1
-rw-r--r--tests/logging/test_opentracing.py1
-rw-r--r--tests/logging/test_remote_handler.py1
-rw-r--r--tests/logging/test_terse_json.py1
-rw-r--r--tests/media/__init__.py1
-rw-r--r--tests/media/test_filepath.py1
-rw-r--r--tests/media/test_html_preview.py1
-rw-r--r--tests/media/test_media_retention.py1
-rw-r--r--tests/media/test_media_storage.py410
-rw-r--r--tests/media/test_oembed.py1
-rw-r--r--tests/media/test_url_previewer.py1
-rw-r--r--tests/metrics/test_metrics.py1
-rw-r--r--tests/module_api/test_account_data_manager.py1
-rw-r--r--tests/module_api/test_api.py3
-rw-r--r--tests/module_api/test_event_unsigned_addition.py1
-rw-r--r--tests/push/test_bulk_push_rule_evaluator.py1
-rw-r--r--tests/push/test_email.py37
-rw-r--r--tests/push/test_http.py82
-rw-r--r--tests/push/test_presentable_names.py1
-rw-r--r--tests/push/test_push_rule_evaluator.py1
-rw-r--r--tests/replication/__init__.py1
-rw-r--r--tests/replication/_base.py6
-rw-r--r--tests/replication/http/__init__.py1
-rw-r--r--tests/replication/http/test__base.py1
-rw-r--r--tests/replication/storage/__init__.py1
-rw-r--r--tests/replication/storage/_base.py1
-rw-r--r--tests/replication/storage/test_events.py128
-rw-r--r--tests/replication/tcp/streams/test_account_data.py1
-rw-r--r--tests/replication/tcp/streams/test_partial_state.py1
-rw-r--r--tests/replication/tcp/streams/test_to_device.py1
-rw-r--r--tests/replication/tcp/streams/test_typing.py1
-rw-r--r--tests/replication/tcp/test_commands.py1
-rw-r--r--tests/replication/tcp/test_handler.py1
-rw-r--r--tests/replication/test_auth.py1
-rw-r--r--tests/replication/test_client_reader_shard.py1
-rw-r--r--tests/replication/test_federation_ack.py1
-rw-r--r--tests/replication/test_federation_sender_shard.py1
-rw-r--r--tests/replication/test_module_cache_invalidation.py1
-rw-r--r--tests/replication/test_multi_media_repo.py235
-rw-r--r--tests/replication/test_pusher_shard.py1
-rw-r--r--tests/replication/test_sharded_event_persister.py1
-rw-r--r--tests/replication/test_sharded_receipts.py1
-rw-r--r--tests/rest/__init__.py1
-rw-r--r--tests/rest/admin/test_admin.py19
-rw-r--r--tests/rest/admin/test_background_updates.py1
-rw-r--r--tests/rest/admin/test_device.py1
-rw-r--r--tests/rest/admin/test_event_reports.py7
-rw-r--r--tests/rest/admin/test_federation.py68
-rw-r--r--tests/rest/admin/test_jwks.py1
-rw-r--r--tests/rest/admin/test_media.py7
-rw-r--r--tests/rest/admin/test_registration_tokens.py1
-rw-r--r--tests/rest/admin/test_room.py139
-rw-r--r--tests/rest/admin/test_server_notice.py1
-rw-r--r--tests/rest/admin/test_statistics.py2
-rw-r--r--tests/rest/admin/test_user.py141
-rw-r--r--tests/rest/admin/test_username_available.py1
-rw-r--r--tests/rest/client/__init__.py1
-rw-r--r--tests/rest/client/sliding_sync/__init__.py13
-rw-r--r--tests/rest/client/sliding_sync/test_connection_tracking.py453
-rw-r--r--tests/rest/client/sliding_sync/test_extension_account_data.py495
-rw-r--r--tests/rest/client/sliding_sync/test_extension_e2ee.py441
-rw-r--r--tests/rest/client/sliding_sync/test_extension_receipts.py679
-rw-r--r--tests/rest/client/sliding_sync/test_extension_to_device.py278
-rw-r--r--tests/rest/client/sliding_sync/test_extension_typing.py482
-rw-r--r--tests/rest/client/sliding_sync/test_extensions.py283
-rw-r--r--tests/rest/client/sliding_sync/test_room_subscriptions.py285
-rw-r--r--tests/rest/client/sliding_sync/test_rooms_invites.py510
-rw-r--r--tests/rest/client/sliding_sync/test_rooms_meta.py710
-rw-r--r--tests/rest/client/sliding_sync/test_rooms_required_state.py707
-rw-r--r--tests/rest/client/sliding_sync/test_rooms_timeline.py575
-rw-r--r--tests/rest/client/sliding_sync/test_sliding_sync.py974
-rw-r--r--tests/rest/client/test_account.py29
-rw-r--r--tests/rest/client/test_account_data.py1
-rw-r--r--tests/rest/client/test_auth.py1
-rw-r--r--tests/rest/client/test_devices.py145
-rw-r--r--tests/rest/client/test_events.py1
-rw-r--r--tests/rest/client/test_filter.py3
-rw-r--r--tests/rest/client/test_keys.py66
-rw-r--r--tests/rest/client/test_login.py213
-rw-r--r--tests/rest/client/test_login_token_request.py1
-rw-r--r--tests/rest/client/test_media.py2677
-rw-r--r--tests/rest/client/test_models.py3
-rw-r--r--tests/rest/client/test_mutual_rooms.py1
-rw-r--r--tests/rest/client/test_notifications.py172
-rw-r--r--tests/rest/client/test_password_policy.py1
-rw-r--r--tests/rest/client/test_power_levels.py1
-rw-r--r--tests/rest/client/test_profile.py1
-rw-r--r--tests/rest/client/test_push_rule_attrs.py1
-rw-r--r--tests/rest/client/test_read_marker.py9
-rw-r--r--tests/rest/client/test_receipts.py1
-rw-r--r--tests/rest/client/test_redactions.py1
-rw-r--r--tests/rest/client/test_register.py11
-rw-r--r--tests/rest/client/test_relations.py10
-rw-r--r--tests/rest/client/test_rendezvous.py436
-rw-r--r--tests/rest/client/test_report_event.py (renamed from tests/rest/client/test_reporting.py)94
-rw-r--r--tests/rest/client/test_retention.py6
-rw-r--r--tests/rest/client/test_rooms.py286
-rw-r--r--tests/rest/client/test_sendtodevice.py72
-rw-r--r--tests/rest/client/test_shadow_banned.py1
-rw-r--r--tests/rest/client/test_sync.py402
-rw-r--r--tests/rest/client/test_third_party_rules.py1
-rw-r--r--tests/rest/client/test_transactions.py1
-rw-r--r--tests/rest/client/test_typing.py1
-rw-r--r--tests/rest/client/test_upgrade_room.py1
-rw-r--r--tests/rest/client/utils.py41
-rw-r--r--tests/rest/key/v2/test_remote_key_resource.py1
-rw-r--r--tests/rest/media/test_domain_blocking.py15
-rw-r--r--tests/rest/media/test_url_preview.py1
-rw-r--r--tests/rest/synapse/__init__.py12
-rw-r--r--tests/rest/synapse/client/__init__.py12
-rw-r--r--tests/rest/synapse/client/test_federation_whitelist.py119
-rw-r--r--tests/rest/test_health.py1
-rw-r--r--tests/server.py158
-rw-r--r--tests/storage/databases/__init__.py1
-rw-r--r--tests/storage/databases/main/__init__.py1
-rw-r--r--tests/storage/databases/main/test_cache.py1
-rw-r--r--tests/storage/databases/main/test_deviceinbox.py1
-rw-r--r--tests/storage/databases/main/test_end_to_end_keys.py1
-rw-r--r--tests/storage/databases/main/test_events_worker.py1
-rw-r--r--tests/storage/databases/main/test_lock.py1
-rw-r--r--tests/storage/databases/main/test_receipts.py1
-rw-r--r--tests/storage/databases/main/test_room.py1
-rw-r--r--tests/storage/test__base.py1
-rw-r--r--tests/storage/test_account_data.py1
-rw-r--r--tests/storage/test_appservice.py1
-rw-r--r--tests/storage/test_background_update.py1
-rw-r--r--tests/storage/test_base.py1
-rw-r--r--tests/storage/test_cleanup_extrems.py19
-rw-r--r--tests/storage/test_client_ips.py1
-rw-r--r--tests/storage/test_database.py1
-rw-r--r--tests/storage/test_devices.py9
-rw-r--r--tests/storage/test_directory.py1
-rw-r--r--tests/storage/test_e2e_room_keys.py1
-rw-r--r--tests/storage/test_end_to_end_keys.py1
-rw-r--r--tests/storage/test_event_chain.py115
-rw-r--r--tests/storage/test_event_federation.py44
-rw-r--r--tests/storage/test_event_metrics.py1
-rw-r--r--tests/storage/test_event_push_actions.py1
-rw-r--r--tests/storage/test_events.py1
-rw-r--r--tests/storage/test_id_generators.py765
-rw-r--r--tests/storage/test_main.py1
-rw-r--r--tests/storage/test_profile.py1
-rw-r--r--tests/storage/test_receipts.py1
-rw-r--r--tests/storage/test_redaction.py1
-rw-r--r--tests/storage/test_registration.py3
-rw-r--r--tests/storage/test_relations.py1
-rw-r--r--tests/storage/test_rollback_worker.py1
-rw-r--r--tests/storage/test_room.py1
-rw-r--r--tests/storage/test_room_search.py30
-rw-r--r--tests/storage/test_roommember.py405
-rw-r--r--tests/storage/test_state.py1
-rw-r--r--tests/storage/test_stream.py1193
-rw-r--r--tests/storage/test_txn_limit.py1
-rw-r--r--tests/storage/test_unsafe_locale.py1
-rw-r--r--tests/storage/test_user_directory.py5
-rw-r--r--tests/storage/test_user_filters.py1
-rw-r--r--tests/storage/util/__init__.py1
-rw-r--r--tests/storage/util/test_partial_state_events_tracker.py1
-rw-r--r--tests/test_distributor.py1
-rw-r--r--tests/test_event_auth.py1
-rw-r--r--tests/test_federation.py1
-rw-r--r--tests/test_phone_home.py1
-rw-r--r--tests/test_server.py9
-rw-r--r--tests/test_state.py1
-rw-r--r--tests/test_test_utils.py1
-rw-r--r--tests/test_types.py72
-rw-r--r--tests/test_utils/__init__.py1
-rw-r--r--tests/test_utils/event_injection.py13
-rw-r--r--tests/test_utils/html_parsers.py1
-rw-r--r--tests/test_utils/oidc.py1
-rw-r--r--tests/test_visibility.py316
-rw-r--r--tests/unittest.py67
-rw-r--r--tests/util/__init__.py1
-rw-r--r--tests/util/caches/__init__.py1
-rw-r--r--tests/util/caches/test_cached_call.py1
-rw-r--r--tests/util/caches/test_deferred_cache.py1
-rw-r--r--tests/util/caches/test_descriptors.py1
-rw-r--r--tests/util/caches/test_response_cache.py1
-rw-r--r--tests/util/test_batching_queue.py1
-rw-r--r--tests/util/test_check_dependencies.py4
-rw-r--r--tests/util/test_dict_cache.py1
-rw-r--r--tests/util/test_expiring_cache.py1
-rw-r--r--tests/util/test_itertools.py1
-rw-r--r--tests/util/test_linearizer.py4
-rw-r--r--tests/util/test_logcontext.py1
-rw-r--r--tests/util/test_lrucache.py32
-rw-r--r--tests/util/test_macaroons.py1
-rw-r--r--tests/util/test_ratelimitutils.py1
-rw-r--r--tests/util/test_retryutils.py1
-rw-r--r--tests/util/test_rwlock.py1
-rw-r--r--tests/util/test_stream_change_cache.py21
-rw-r--r--tests/util/test_stringutils.py1
-rw-r--r--tests/util/test_task_scheduler.py1
-rw-r--r--tests/util/test_threepids.py1
-rw-r--r--tests/util/test_treecache.py1
-rw-r--r--tests/util/test_wheel_timer.py1
-rw-r--r--tests/utils.py49
277 files changed, 1363 insertions, 22374 deletions
diff --git a/tests/__init__.py b/tests/__init__.py

index 4c8633b445..775bec0227 100644 --- a/tests/__init__.py +++ b/tests/__init__.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2014-2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index bd229cf7e9..dbcb13c0a5 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2015 - 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -128,7 +127,7 @@ class AuthTestCase(unittest.HomeserverTestCase): token="foobar", url="a_url", sender=self.test_user, - ip_range_whitelist=IPSet(["192.168.0.0/16"]), + ip_range_whitelist=IPSet(["192.168/16"]), ) self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_user_by_access_token = AsyncMock(return_value=None) @@ -147,7 +146,7 @@ class AuthTestCase(unittest.HomeserverTestCase): token="foobar", url="a_url", sender=self.test_user, - ip_range_whitelist=IPSet(["192.168.0.0/16"]), + ip_range_whitelist=IPSet(["192.168/16"]), ) self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_user_by_access_token = AsyncMock(return_value=None) diff --git a/tests/api/test_errors.py b/tests/api/test_errors.py
index efa3addf00..5e324595a4 100644 --- a/tests/api/test_errors.py +++ b/tests/api/test_errors.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2023 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -33,14 +32,18 @@ class LimitExceededErrorTestCase(unittest.TestCase): self.assertIn("needle", err.debug_context) self.assertNotIn("needle", serialised) + # Create a sub-class to avoid mutating the class-level property. + class LimitExceededErrorHeaders(LimitExceededError): + include_retry_after_header = True + def test_limit_exceeded_header(self) -> None: - err = LimitExceededError(limiter_name="test", retry_after_ms=100) + err = self.LimitExceededErrorHeaders(limiter_name="test", retry_after_ms=100) self.assertEqual(err.error_dict(None).get("retry_after_ms"), 100) assert err.headers is not None self.assertEqual(err.headers.get("Retry-After"), "1") def test_limit_exceeded_rounding(self) -> None: - err = LimitExceededError(limiter_name="test", retry_after_ms=3001) + err = self.LimitExceededErrorHeaders(limiter_name="test", retry_after_ms=3001) self.assertEqual(err.error_dict(None).get("retry_after_ms"), 3001) assert err.headers is not None self.assertEqual(err.headers.get("Retry-After"), "4") diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 743c52d969..73678b6f33 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py
@@ -1,9 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019 The Matrix.org Foundation C.I.C. -# Copyright 2017 Vector Creations Ltd -# Copyright 2015, 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index a59e168db1..a24638c9ef 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py
@@ -116,9 +116,8 @@ class TestRatelimiter(unittest.HomeserverTestCase): # Should raise with self.assertRaises(LimitExceededError) as context: self.get_success_or_raise( - limiter.ratelimit(None, key="test_id", _time_now_s=5), by=0.5 + limiter.ratelimit(None, key="test_id", _time_now_s=5) ) - self.assertEqual(context.exception.retry_after_ms, 5000) # Shouldn't raise @@ -193,7 +192,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): # Second attempt, 1s later, will fail with self.assertRaises(LimitExceededError) as context: self.get_success_or_raise( - limiter.ratelimit(None, key=("test_id",), _time_now_s=1), by=0.5 + limiter.ratelimit(None, key=("test_id",), _time_now_s=1) ) self.assertEqual(context.exception.retry_after_ms, 9000) diff --git a/tests/app/test_homeserver_start.py b/tests/app/test_homeserver_start.py
index 9dc20800b2..576323e391 100644 --- a/tests/app/test_homeserver_start.py +++ b/tests/app/test_homeserver_start.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/appservice/__init__.py b/tests/appservice/__init__.py
index 6a72062b0c..3d833a2e44 100644 --- a/tests/appservice/__init__.py +++ b/tests/appservice/__init__.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2015, 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py
index 0f19736540..344be6e1de 100644 --- a/tests/appservice/test_api.py +++ b/tests/appservice/test_api.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index 3fa4426638..fd7d089edf 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2015, 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index a1c7ccdd0b..b6ecacccb5 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2015, 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/config/__init__.py b/tests/config/__init__.py
index 587ee42067..3d833a2e44 100644 --- a/tests/config/__init__.py +++ b/tests/config/__init__.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/config/test___main__.py b/tests/config/test___main__.py
index ca8884823c..7af635f5ec 100644 --- a/tests/config/test___main__.py +++ b/tests/config/test___main__.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/config/test_appservice.py b/tests/config/test_appservice.py
index e3021b59d8..3a2e268d73 100644 --- a/tests/config/test_appservice.py +++ b/tests/config/test_appservice.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2023 Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/config/test_background_update.py b/tests/config/test_background_update.py
index 678a068481..984365d595 100644 --- a/tests/config/test_background_update.py +++ b/tests/config/test_background_update.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/config/test_base.py b/tests/config/test_base.py
index edb91bc4d9..41ade50714 100644 --- a/tests/config/test_base.py +++ b/tests/config/test_base.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py
index 631263b5ca..2108b40ff4 100644 --- a/tests/config/test_cache.py +++ b/tests/config/test_cache.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py
index 5ab96e16e1..95b13defba 100644 --- a/tests/config/test_generate.py +++ b/tests/config/test_generate.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index 479d2aab91..ef9976fb8d 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py
@@ -1,8 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. -# Copyright 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/config/test_oauth_delegation.py b/tests/config/test_oauth_delegation.py
index 713bddeb90..79c10b10a6 100644 --- a/tests/config/test_oauth_delegation.py +++ b/tests/config/test_oauth_delegation.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2023 Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/config/test_ratelimiting.py b/tests/config/test_ratelimiting.py
index 0d52a96858..8267089298 100644 --- a/tests/config/test_ratelimiting.py +++ b/tests/config/test_ratelimiting.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/config/test_registration_config.py b/tests/config/test_registration_config.py
index 7fd6df2f93..16e3e13dd0 100644 --- a/tests/config/test_registration_config.py +++ b/tests/config/test_registration_config.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/config/test_room_directory.py b/tests/config/test_room_directory.py
index e25f7787f4..574697cfd9 100644 --- a/tests/config/test_room_directory.py +++ b/tests/config/test_room_directory.py
@@ -17,31 +17,15 @@ # [This file includes modifications made by New Vector Limited] # # -import yaml -from twisted.test.proto_helpers import MemoryReactor +import yaml -import synapse.rest.admin -import synapse.rest.client.login -import synapse.rest.client.room from synapse.config.room_directory import RoomDirectoryConfig -from synapse.server import HomeServer -from synapse.util import Clock from tests import unittest -from tests.unittest import override_config - -class RoomDirectoryConfigTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastores().main - - servlets = [ - synapse.rest.admin.register_servlets, - synapse.rest.client.login.register_servlets, - synapse.rest.client.room.register_servlets, - ] +class RoomDirectoryConfigTestCase(unittest.TestCase): def test_alias_creation_acl(self) -> None: config = yaml.safe_load( """ @@ -183,20 +167,3 @@ class RoomDirectoryConfigTestCase(unittest.HomeserverTestCase): aliases=["#unofficial_st:example.com", "#blah:example.com"], ) ) - - @override_config({"room_list_publication_rules": []}) - def test_room_creation_when_publishing_denied(self) -> None: - """ - Test that when room publishing is denied via the config that new rooms can - still be created and that the newly created room is not public. - """ - - user = self.register_user("alice", "pass") - token = self.login("alice", "pass") - room_id = self.helper.create_room_as(user, is_public=True, tok=token) - - res = self.get_success(self.store.get_room(room_id)) - assert res is not None - is_public, _ = res - - self.assertFalse(is_public) diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py
index 5c2cefaf5b..6cfdd181f0 100644 --- a/tests/config/test_tls.py +++ b/tests/config/test_tls.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019 Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/config/test_util.py b/tests/config/test_util.py
index 64538a4628..8071850cce 100644 --- a/tests/config/test_util.py +++ b/tests/config/test_util.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/config/test_workers.py b/tests/config/test_workers.py
index 64c0285d01..7a97246064 100644 --- a/tests/config/test_workers.py +++ b/tests/config/test_workers.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/config/utils.py b/tests/config/utils.py
index 11140ff979..7842b88e72 100644 --- a/tests/config/utils.py +++ b/tests/config/utils.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/crypto/__init__.py b/tests/crypto/__init__.py
index fcd2134c89..3d833a2e44 100644 --- a/tests/crypto/__init__.py +++ b/tests/crypto/__init__.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2014-2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py
index d7b9fb8bc6..f8369f75f7 100644 --- a/tests/crypto/test_event_signing.py +++ b/tests/crypto/test_event_signing.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2015, 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 3bfaf1c80d..c33b03d37d 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2017-2021 The Matrix.org Foundation C.I.C # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/events/test_auto_accept_invites.py b/tests/events/test_auto_accept_invites.py deleted file mode 100644
index 7fb4d4fa90..0000000000 --- a/tests/events/test_auto_accept_invites.py +++ /dev/null
@@ -1,657 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright 2021 The Matrix.org Foundation C.I.C -# Copyright (C) 2024 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -# Originally licensed under the Apache License, Version 2.0: -# <http://www.apache.org/licenses/LICENSE-2.0>. -# -# [This file includes modifications made by New Vector Limited] -# -# -import asyncio -from asyncio import Future -from http import HTTPStatus -from typing import Any, Awaitable, Dict, List, Optional, Tuple, TypeVar, cast -from unittest.mock import Mock - -import attr -from parameterized import parameterized - -from twisted.test.proto_helpers import MemoryReactor - -from synapse.api.constants import EventTypes -from synapse.api.errors import SynapseError -from synapse.config.auto_accept_invites import AutoAcceptInvitesConfig -from synapse.events.auto_accept_invites import InviteAutoAccepter -from synapse.federation.federation_base import event_from_pdu_json -from synapse.handlers.sync import JoinedSyncResult, SyncRequestKey, SyncVersion -from synapse.module_api import ModuleApi -from synapse.rest import admin -from synapse.rest.client import login, room -from synapse.server import HomeServer -from synapse.types import StreamToken, create_requester -from synapse.util import Clock - -from tests.handlers.test_sync import generate_sync_config -from tests.unittest import ( - FederatingHomeserverTestCase, - HomeserverTestCase, - TestCase, - override_config, -) - - -class AutoAcceptInvitesTestCase(FederatingHomeserverTestCase): - """ - Integration test cases for auto-accepting invites. 
- """ - - servlets = [ - admin.register_servlets, - login.register_servlets, - room.register_servlets, - ] - - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - hs = self.setup_test_homeserver() - self.handler = hs.get_federation_handler() - self.store = hs.get_datastores().main - return hs - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.sync_handler = self.hs.get_sync_handler() - self.module_api = hs.get_module_api() - - @parameterized.expand( - [ - [False], - [True], - ] - ) - @override_config( - { - "auto_accept_invites": { - "enabled": True, - }, - } - ) - def test_auto_accept_invites(self, direct_room: bool) -> None: - """Test that a user automatically joins a room when invited, if the - module is enabled. - """ - # A local user who sends an invite - inviting_user_id = self.register_user("inviter", "pass") - inviting_user_tok = self.login("inviter", "pass") - - # A local user who receives an invite - invited_user_id = self.register_user("invitee", "pass") - self.login("invitee", "pass") - - # Create a room and send an invite to the other user - room_id = self.helper.create_room_as( - inviting_user_id, - is_public=False, - tok=inviting_user_tok, - ) - - self.helper.invite( - room_id, - inviting_user_id, - invited_user_id, - tok=inviting_user_tok, - extra_data={"is_direct": direct_room}, - ) - - # Check that the invite receiving user has automatically joined the room when syncing - join_updates, _ = sync_join(self, invited_user_id) - self.assertEqual(len(join_updates), 1) - - join_update: JoinedSyncResult = join_updates[0] - self.assertEqual(join_update.room_id, room_id) - - @override_config( - { - "auto_accept_invites": { - "enabled": False, - }, - } - ) - def test_module_not_enabled(self) -> None: - """Test that a user does not automatically join a room when invited, - if the module is not enabled. 
- """ - # A local user who sends an invite - inviting_user_id = self.register_user("inviter", "pass") - inviting_user_tok = self.login("inviter", "pass") - - # A local user who receives an invite - invited_user_id = self.register_user("invitee", "pass") - self.login("invitee", "pass") - - # Create a room and send an invite to the other user - room_id = self.helper.create_room_as( - inviting_user_id, is_public=False, tok=inviting_user_tok - ) - - self.helper.invite( - room_id, - inviting_user_id, - invited_user_id, - tok=inviting_user_tok, - ) - - # Check that the invite receiving user has not automatically joined the room when syncing - join_updates, _ = sync_join(self, invited_user_id) - self.assertEqual(len(join_updates), 0) - - @override_config( - { - "auto_accept_invites": { - "enabled": True, - }, - } - ) - def test_invite_from_remote_user(self) -> None: - """Test that an invite from a remote user results in the invited user - automatically joining the room. - """ - # A remote user who sends the invite - remote_server = "otherserver" - remote_user = "@otheruser:" + remote_server - - # A local user who creates the room - creator_user_id = self.register_user("creator", "pass") - creator_user_tok = self.login("creator", "pass") - - # A local user who receives an invite - invited_user_id = self.register_user("invitee", "pass") - self.login("invitee", "pass") - - room_id = self.helper.create_room_as( - room_creator=creator_user_id, tok=creator_user_tok - ) - room_version = self.get_success(self.store.get_room_version(room_id)) - - invite_event = event_from_pdu_json( - { - "type": EventTypes.Member, - "content": {"membership": "invite"}, - "room_id": room_id, - "sender": remote_user, - "state_key": invited_user_id, - "depth": 32, - "prev_events": [], - "auth_events": [], - "origin_server_ts": self.clock.time_msec(), - }, - room_version, - ) - self.get_success( - self.handler.on_invite_request( - remote_server, - invite_event, - invite_event.room_version, - ) - ) - - 
# Check that the invite receiving user has automatically joined the room when syncing - join_updates, _ = sync_join(self, invited_user_id) - self.assertEqual(len(join_updates), 1) - - join_update: JoinedSyncResult = join_updates[0] - self.assertEqual(join_update.room_id, room_id) - - @parameterized.expand( - [ - [False, False], - [True, True], - ] - ) - @override_config( - { - "auto_accept_invites": { - "enabled": True, - "only_for_direct_messages": True, - }, - } - ) - def test_accept_invite_direct_message( - self, - direct_room: bool, - expect_auto_join: bool, - ) -> None: - """Tests that, if the module is configured to only accept DM invites, invites to DM rooms are still - automatically accepted. Otherwise they are rejected. - """ - # A local user who sends an invite - inviting_user_id = self.register_user("inviter", "pass") - inviting_user_tok = self.login("inviter", "pass") - - # A local user who receives an invite - invited_user_id = self.register_user("invitee", "pass") - self.login("invitee", "pass") - - # Create a room and send an invite to the other user - room_id = self.helper.create_room_as( - inviting_user_id, - is_public=False, - tok=inviting_user_tok, - ) - - self.helper.invite( - room_id, - inviting_user_id, - invited_user_id, - tok=inviting_user_tok, - extra_data={"is_direct": direct_room}, - ) - - if expect_auto_join: - # Check that the invite receiving user has automatically joined the room when syncing - join_updates, _ = sync_join(self, invited_user_id) - self.assertEqual(len(join_updates), 1) - - join_update: JoinedSyncResult = join_updates[0] - self.assertEqual(join_update.room_id, room_id) - else: - # Check that the invite receiving user has not automatically joined the room when syncing - join_updates, _ = sync_join(self, invited_user_id) - self.assertEqual(len(join_updates), 0) - - @parameterized.expand( - [ - [False, True], - [True, False], - ] - ) - @override_config( - { - "auto_accept_invites": { - "enabled": True, - 
"only_from_local_users": True, - }, - } - ) - def test_accept_invite_local_user( - self, remote_inviter: bool, expect_auto_join: bool - ) -> None: - """Tests that, if the module is configured to only accept invites from local users, invites - from local users are still automatically accepted. Otherwise they are rejected. - """ - # A local user who sends an invite - creator_user_id = self.register_user("inviter", "pass") - creator_user_tok = self.login("inviter", "pass") - - # A local user who receives an invite - invited_user_id = self.register_user("invitee", "pass") - self.login("invitee", "pass") - - # Create a room and send an invite to the other user - room_id = self.helper.create_room_as( - creator_user_id, is_public=False, tok=creator_user_tok - ) - - if remote_inviter: - room_version = self.get_success(self.store.get_room_version(room_id)) - - # A remote user who sends the invite - remote_server = "otherserver" - remote_user = "@otheruser:" + remote_server - - invite_event = event_from_pdu_json( - { - "type": EventTypes.Member, - "content": {"membership": "invite"}, - "room_id": room_id, - "sender": remote_user, - "state_key": invited_user_id, - "depth": 32, - "prev_events": [], - "auth_events": [], - "origin_server_ts": self.clock.time_msec(), - }, - room_version, - ) - self.get_success( - self.handler.on_invite_request( - remote_server, - invite_event, - invite_event.room_version, - ) - ) - else: - self.helper.invite( - room_id, - creator_user_id, - invited_user_id, - tok=creator_user_tok, - ) - - if expect_auto_join: - # Check that the invite receiving user has automatically joined the room when syncing - join_updates, _ = sync_join(self, invited_user_id) - self.assertEqual(len(join_updates), 1) - - join_update: JoinedSyncResult = join_updates[0] - self.assertEqual(join_update.room_id, room_id) - else: - # Check that the invite receiving user has not automatically joined the room when syncing - join_updates, _ = sync_join(self, invited_user_id) - 
self.assertEqual(len(join_updates), 0) - - -_request_key = 0 - - -def generate_request_key() -> SyncRequestKey: - global _request_key - _request_key += 1 - return ("request_key", _request_key) - - -def sync_join( - testcase: HomeserverTestCase, - user_id: str, - since_token: Optional[StreamToken] = None, -) -> Tuple[List[JoinedSyncResult], StreamToken]: - """Perform a sync request for the given user and return the user join updates - they've received, as well as the next_batch token. - - This method assumes testcase.sync_handler points to the homeserver's sync handler. - - Args: - testcase: The testcase that is currently being run. - user_id: The ID of the user to generate a sync response for. - since_token: An optional token to indicate from at what point to sync from. - - Returns: - A tuple containing a list of join updates, and the sync response's - next_batch token. - """ - requester = create_requester(user_id) - sync_config = generate_sync_config(requester.user.to_string()) - sync_result = testcase.get_success( - testcase.hs.get_sync_handler().wait_for_sync_for_user( - requester, - sync_config, - SyncVersion.SYNC_V2, - generate_request_key(), - since_token, - ) - ) - - return sync_result.joined, sync_result.next_batch - - -class InviteAutoAccepterInternalTestCase(TestCase): - """ - Test cases which exercise the internals of the InviteAutoAccepter. - """ - - def setUp(self) -> None: - self.module = create_module() - self.user_id = "@peter:test" - self.invitee = "@lesley:test" - self.remote_invitee = "@thomas:remote" - - # We know our module API is a mock, but mypy doesn't. - self.mocked_update_membership: Mock = self.module._api.update_room_membership # type: ignore[assignment] - - async def test_accept_invite_with_failures(self) -> None: - """Tests that receiving an invite for a local user makes the module attempt to - make the invitee join the room. 
This test verifies that it works if the call to - update membership returns exceptions before successfully completing and returning an event. - """ - invite = MockEvent( - sender="@inviter:test", - state_key="@invitee:test", - type="m.room.member", - content={"membership": "invite"}, - ) - - join_event = MockEvent( - sender="someone", - state_key="someone", - type="m.room.member", - content={"membership": "join"}, - ) - # the first two calls raise an exception while the third call is successful - self.mocked_update_membership.side_effect = [ - SynapseError(HTTPStatus.FORBIDDEN, "Forbidden"), - SynapseError(HTTPStatus.FORBIDDEN, "Forbidden"), - make_awaitable(join_event), - ] - - # Stop mypy from complaining that we give on_new_event a MockEvent rather than an - # EventBase. - await self.module.on_new_event(event=invite) # type: ignore[arg-type] - - await self.retry_assertions( - self.mocked_update_membership, - 3, - sender=invite.state_key, - target=invite.state_key, - room_id=invite.room_id, - new_membership="join", - ) - - async def test_accept_invite_failures(self) -> None: - """Tests that receiving an invite for a local user makes the module attempt to - make the invitee join the room. This test verifies that if the update_membership call - fails consistently, _retry_make_join will break the loop after the set number of retries and - execution will continue. - """ - invite = MockEvent( - sender=self.user_id, - state_key=self.invitee, - type="m.room.member", - content={"membership": "invite"}, - ) - self.mocked_update_membership.side_effect = SynapseError( - HTTPStatus.FORBIDDEN, "Forbidden" - ) - - # Stop mypy from complaining that we give on_new_event a MockEvent rather than an - # EventBase. 
- await self.module.on_new_event(event=invite) # type: ignore[arg-type] - - await self.retry_assertions( - self.mocked_update_membership, - 5, - sender=invite.state_key, - target=invite.state_key, - room_id=invite.room_id, - new_membership="join", - ) - - async def test_not_state(self) -> None: - """Tests that receiving an invite that's not a state event does nothing.""" - invite = MockEvent( - sender=self.user_id, type="m.room.member", content={"membership": "invite"} - ) - - # Stop mypy from complaining that we give on_new_event a MockEvent rather than an - # EventBase. - await self.module.on_new_event(event=invite) # type: ignore[arg-type] - - self.mocked_update_membership.assert_not_called() - - async def test_not_invite(self) -> None: - """Tests that receiving a membership update that's not an invite does nothing.""" - invite = MockEvent( - sender=self.user_id, - state_key=self.user_id, - type="m.room.member", - content={"membership": "join"}, - ) - - # Stop mypy from complaining that we give on_new_event a MockEvent rather than an - # EventBase. - await self.module.on_new_event(event=invite) # type: ignore[arg-type] - - self.mocked_update_membership.assert_not_called() - - async def test_not_membership(self) -> None: - """Tests that receiving a state event that's not a membership update does - nothing. - """ - invite = MockEvent( - sender=self.user_id, - state_key=self.user_id, - type="org.matrix.test", - content={"foo": "bar"}, - ) - - # Stop mypy from complaining that we give on_new_event a MockEvent rather than an - # EventBase. 
- await self.module.on_new_event(event=invite) # type: ignore[arg-type] - - self.mocked_update_membership.assert_not_called() - - def test_config_parse(self) -> None: - """Tests that a correct configuration parses.""" - config = { - "auto_accept_invites": { - "enabled": True, - "only_for_direct_messages": True, - "only_from_local_users": True, - } - } - parsed_config = AutoAcceptInvitesConfig() - parsed_config.read_config(config) - - self.assertTrue(parsed_config.enabled) - self.assertTrue(parsed_config.accept_invites_only_for_direct_messages) - self.assertTrue(parsed_config.accept_invites_only_from_local_users) - - def test_runs_on_only_one_worker(self) -> None: - """ - Tests that the module only runs on the specified worker. - """ - # By default, we run on the main process... - main_module = create_module( - config_override={"auto_accept_invites": {"enabled": True}}, worker_name=None - ) - cast( - Mock, main_module._api.register_third_party_rules_callbacks - ).assert_called_once() - - # ...and not on other workers (like synchrotrons)... - sync_module = create_module(worker_name="synchrotron42") - cast( - Mock, sync_module._api.register_third_party_rules_callbacks - ).assert_not_called() - - # ...unless we configured them to be the designated worker. - specified_module = create_module( - config_override={ - "auto_accept_invites": { - "enabled": True, - "worker_to_run_on": "account_data1", - } - }, - worker_name="account_data1", - ) - cast( - Mock, specified_module._api.register_third_party_rules_callbacks - ).assert_called_once() - - async def retry_assertions( - self, mock: Mock, call_count: int, **kwargs: Any - ) -> None: - """ - This is a hacky way to ensure that the assertions are not called before the other coroutine - has a chance to call `update_room_membership`. It catches the exception caused by a failure, - and sleeps the thread before retrying, up until 5 tries. 
- - Args: - call_count: the number of times the mock should have been called - mock: the mocked function we want to assert on - kwargs: keyword arguments to assert that the mock was called with - """ - - i = 0 - while i < 5: - try: - # Check that the mocked method is called the expected amount of times and with the right - # arguments to attempt to make the user join the room. - mock.assert_called_with(**kwargs) - self.assertEqual(call_count, mock.call_count) - break - except AssertionError as e: - i += 1 - if i == 5: - # we've used up the tries, force the test to fail as we've already caught the exception - self.fail(e) - await asyncio.sleep(1) - - -@attr.s(auto_attribs=True) -class MockEvent: - """Mocks an event. Only exposes properties the module uses.""" - - sender: str - type: str - content: Dict[str, Any] - room_id: str = "!someroom" - state_key: Optional[str] = None - - def is_state(self) -> bool: - """Checks if the event is a state event by checking if it has a state key.""" - return self.state_key is not None - - @property - def membership(self) -> str: - """Extracts the membership from the event. Should only be called on an event - that's a membership event, and will raise a KeyError otherwise. - """ - membership: str = self.content["membership"] - return membership - - -T = TypeVar("T") -TV = TypeVar("TV") - - -async def make_awaitable(value: T) -> T: - return value - - -def make_multiple_awaitable(result: TV) -> Awaitable[TV]: - """ - Makes an awaitable, suitable for mocking an `async` function. - This uses Futures as they can be awaited multiple times so can be returned - to multiple callers. - """ - future: Future[TV] = Future() - future.set_result(result) - return future - - -def create_module( - config_override: Optional[Dict[str, Any]] = None, worker_name: Optional[str] = None -) -> InviteAutoAccepter: - # Create a mock based on the ModuleApi spec, but override some mocked functions - # because some capabilities are needed for running the tests. 
- module_api = Mock(spec=ModuleApi) - module_api.is_mine.side_effect = lambda a: a.split(":")[1] == "test" - module_api.worker_name = worker_name - module_api.sleep.return_value = make_multiple_awaitable(None) - - if config_override is None: - config_override = {} - - config = AutoAcceptInvitesConfig() - config.read_config(config_override) - - return InviteAutoAccepter(config, module_api) diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index e48983ddfe..ce66241763 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -36,7 +35,7 @@ from synapse.server import HomeServer from synapse.types import JsonDict, StreamToken, create_requester from synapse.util import Clock -from tests.handlers.test_sync import SyncRequestKey, SyncVersion, generate_sync_config +from tests.handlers.test_sync import generate_sync_config from tests.unittest import ( FederatingHomeserverTestCase, HomeserverTestCase, @@ -498,15 +497,6 @@ def send_presence_update( return channel.json_body -_request_key = 0 - - -def generate_request_key() -> SyncRequestKey: - global _request_key - _request_key += 1 - return ("request_key", _request_key) - - def sync_presence( testcase: HomeserverTestCase, user_id: str, @@ -530,11 +520,7 @@ def sync_presence( sync_config = generate_sync_config(requester.user.to_string()) sync_result = testcase.get_success( testcase.hs.get_sync_handler().wait_for_sync_for_user( - requester, - sync_config, - SyncVersion.SYNC_V2, - generate_request_key(), - since_token, + requester, sync_config, since_token ) ) diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index f96bbe7705..61640c8c2f 100644 --- a/tests/events/test_snapshot.py +++ b/tests/events/test_snapshot.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index 30f8787758..bf0da95d12 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2015, 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -32,7 +31,6 @@ from synapse.events.utils import ( PowerLevelsContent, SerializeEventConfig, _split_field, - clone_event, copy_and_fixup_power_levels_contents, maybe_upsert_event_field, prune_event, @@ -612,32 +610,6 @@ class PruneEventTestCase(stdlib_unittest.TestCase): ) -class CloneEventTestCase(stdlib_unittest.TestCase): - def test_unsigned_is_copied(self) -> None: - original = make_event_from_dict( - { - "type": "A", - "event_id": "$test:domain", - "unsigned": {"a": 1, "b": 2}, - }, - RoomVersions.V1, - {"txn_id": "txn"}, - ) - original.internal_metadata.stream_ordering = 1234 - self.assertEqual(original.internal_metadata.stream_ordering, 1234) - original.internal_metadata.instance_name = "worker1" - self.assertEqual(original.internal_metadata.instance_name, "worker1") - - cloned = clone_event(original) - cloned.unsigned["b"] = 3 - - self.assertEqual(original.unsigned, {"a": 1, "b": 2}) - self.assertEqual(cloned.unsigned, {"a": 1, "b": 3}) - self.assertEqual(cloned.internal_metadata.stream_ordering, 1234) - self.assertEqual(cloned.internal_metadata.instance_name, "worker1") - self.assertEqual(cloned.internal_metadata.txn_id, "txn") - - class SerializeEventTestCase(stdlib_unittest.TestCase): def serialize(self, ev: EventBase, fields: Optional[List[str]]) -> JsonDict: return serialize_event( diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 9bd97e5d4e..8a06695ff6 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019 Matrix.org Foundation # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py
index 585f3b798c..5313bb33a3 100644 --- a/tests/federation/test_federation_client.py +++ b/tests/federation/test_federation_client.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 Matrix.org Federation C.I.C # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/federation/test_federation_media.py b/tests/federation/test_federation_media.py deleted file mode 100644
index 0dcf20f5f5..0000000000 --- a/tests/federation/test_federation_media.py +++ /dev/null
@@ -1,258 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2024 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -# Originally licensed under the Apache License, Version 2.0: -# <http://www.apache.org/licenses/LICENSE-2.0>. -# -# [This file includes modifications made by New Vector Limited] -# -# -import io -import os -import shutil -import tempfile - -from twisted.test.proto_helpers import MemoryReactor - -from synapse.media.filepath import MediaFilePaths -from synapse.media.media_storage import MediaStorage -from synapse.media.storage_provider import ( - FileStorageProviderBackend, - StorageProviderWrapper, -) -from synapse.server import HomeServer -from synapse.types import UserID -from synapse.util import Clock - -from tests import unittest -from tests.media.test_media_storage import small_png -from tests.test_utils import SMALL_PNG - - -class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase): - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - super().prepare(reactor, clock, hs) - self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-") - self.addCleanup(shutil.rmtree, self.test_dir) - self.primary_base_path = os.path.join(self.test_dir, "primary") - self.secondary_base_path = os.path.join(self.test_dir, "secondary") - - hs.config.media.media_store_path = self.primary_base_path - - storage_providers = [ - StorageProviderWrapper( - FileStorageProviderBackend(hs, self.secondary_base_path), - store_local=True, - store_remote=False, - store_synchronous=True, - ) - ] - - self.filepaths = 
MediaFilePaths(self.primary_base_path) - self.media_storage = MediaStorage( - hs, self.primary_base_path, self.filepaths, storage_providers - ) - self.media_repo = hs.get_media_repository() - - def test_file_download(self) -> None: - content = io.BytesIO(b"file_to_stream") - content_uri = self.get_success( - self.media_repo.create_content( - "text/plain", - "test_upload", - content, - 46, - UserID.from_string("@user_id:whatever.org"), - ) - ) - # test with a text file - channel = self.make_signed_federation_request( - "GET", - f"/_matrix/federation/v1/media/download/{content_uri.media_id}", - ) - self.pump() - self.assertEqual(200, channel.code) - - content_type = channel.headers.getRawHeaders("content-type") - assert content_type is not None - assert "multipart/mixed" in content_type[0] - assert "boundary" in content_type[0] - - # extract boundary - boundary = content_type[0].split("boundary=")[1] - # split on boundary and check that json field and expected value exist - stripped = channel.text_body.split("\r\n" + "--" + boundary) - # TODO: the json object expected will change once MSC3911 is implemented, currently - # {} is returned for all requests as a placeholder (per MSC3196) - found_json = any( - "\r\nContent-Type: application/json\r\n\r\n{}" in field - for field in stripped - ) - self.assertTrue(found_json) - - # check that the text file and expected value exist - found_file = any( - "\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_to_stream" - in field - for field in stripped - ) - self.assertTrue(found_file) - - content = io.BytesIO(SMALL_PNG) - content_uri = self.get_success( - self.media_repo.create_content( - "image/png", - "test_png_upload", - content, - 67, - UserID.from_string("@user_id:whatever.org"), - ) - ) - # test with an image file - channel = self.make_signed_federation_request( - "GET", - f"/_matrix/federation/v1/media/download/{content_uri.media_id}", - ) - self.pump() - self.assertEqual(200, 
channel.code) - - content_type = channel.headers.getRawHeaders("content-type") - assert content_type is not None - assert "multipart/mixed" in content_type[0] - assert "boundary" in content_type[0] - - # extract boundary - boundary = content_type[0].split("boundary=")[1] - # split on boundary and check that json field and expected value exist - body = channel.result.get("body") - assert body is not None - stripped_bytes = body.split(b"\r\n" + b"--" + boundary.encode("utf-8")) - found_json = any( - b"\r\nContent-Type: application/json\r\n\r\n{}" in field - for field in stripped_bytes - ) - self.assertTrue(found_json) - - # check that the png file exists and matches what was uploaded - found_file = any(SMALL_PNG in field for field in stripped_bytes) - self.assertTrue(found_file) - - -class FederationThumbnailTest(unittest.FederatingHomeserverTestCase): - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - super().prepare(reactor, clock, hs) - self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-") - self.addCleanup(shutil.rmtree, self.test_dir) - self.primary_base_path = os.path.join(self.test_dir, "primary") - self.secondary_base_path = os.path.join(self.test_dir, "secondary") - - hs.config.media.media_store_path = self.primary_base_path - - storage_providers = [ - StorageProviderWrapper( - FileStorageProviderBackend(hs, self.secondary_base_path), - store_local=True, - store_remote=False, - store_synchronous=True, - ) - ] - - self.filepaths = MediaFilePaths(self.primary_base_path) - self.media_storage = MediaStorage( - hs, self.primary_base_path, self.filepaths, storage_providers - ) - self.media_repo = hs.get_media_repository() - - def test_thumbnail_download_scaled(self) -> None: - content = io.BytesIO(small_png.data) - content_uri = self.get_success( - self.media_repo.create_content( - "image/png", - "test_png_thumbnail", - content, - 67, - UserID.from_string("@user_id:whatever.org"), - ) - ) - # test with an image file - 
channel = self.make_signed_federation_request( - "GET", - f"/_matrix/federation/v1/media/thumbnail/{content_uri.media_id}?width=32&height=32&method=scale", - ) - self.pump() - self.assertEqual(200, channel.code) - - content_type = channel.headers.getRawHeaders("content-type") - assert content_type is not None - assert "multipart/mixed" in content_type[0] - assert "boundary" in content_type[0] - - # extract boundary - boundary = content_type[0].split("boundary=")[1] - # split on boundary and check that json field and expected value exist - body = channel.result.get("body") - assert body is not None - stripped_bytes = body.split(b"\r\n" + b"--" + boundary.encode("utf-8")) - found_json = any( - b"\r\nContent-Type: application/json\r\n\r\n{}" in field - for field in stripped_bytes - ) - self.assertTrue(found_json) - - # check that the png file exists and matches the expected scaled bytes - found_file = any(small_png.expected_scaled in field for field in stripped_bytes) - self.assertTrue(found_file) - - def test_thumbnail_download_cropped(self) -> None: - content = io.BytesIO(small_png.data) - content_uri = self.get_success( - self.media_repo.create_content( - "image/png", - "test_png_thumbnail", - content, - 67, - UserID.from_string("@user_id:whatever.org"), - ) - ) - # test with an image file - channel = self.make_signed_federation_request( - "GET", - f"/_matrix/federation/v1/media/thumbnail/{content_uri.media_id}?width=32&height=32&method=crop", - ) - self.pump() - self.assertEqual(200, channel.code) - - content_type = channel.headers.getRawHeaders("content-type") - assert content_type is not None - assert "multipart/mixed" in content_type[0] - assert "boundary" in content_type[0] - - # extract boundary - boundary = content_type[0].split("boundary=")[1] - # split on boundary and check that json field and expected value exist - body = channel.result.get("body") - assert body is not None - stripped_bytes = body.split(b"\r\n" + b"--" + boundary.encode("utf-8")) - 
found_json = any( - b"\r\nContent-Type: application/json\r\n\r\n{}" in field - for field in stripped_bytes - ) - self.assertTrue(found_json) - - # check that the png file exists and matches the expected cropped bytes - found_file = any( - small_png.expected_cropped in field for field in stripped_bytes - ) - self.assertTrue(found_file) diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 6a8887fe74..9073afc70e 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py
@@ -27,8 +27,6 @@ from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms -from synapse.api.presence import UserPresenceState -from synapse.federation.sender.per_destination_queue import MAX_PRESENCE_STATES_PER_EDU from synapse.federation.units import Transaction from synapse.handlers.device import DeviceHandler from synapse.rest import admin @@ -268,123 +266,6 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): ) -class FederationSenderPresenceTestCases(HomeserverTestCase): - """ - Test federation sending for presence updates. - """ - - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.federation_transport_client = Mock(spec=["send_transaction"]) - self.federation_transport_client.send_transaction = AsyncMock() - hs = self.setup_test_homeserver( - federation_transport_client=self.federation_transport_client, - ) - - return hs - - def default_config(self) -> JsonDict: - config = super().default_config() - config["federation_sender_instances"] = None - return config - - def test_presence_simple(self) -> None: - "Test that sending a single presence update works" - - mock_send_transaction: AsyncMock = ( - self.federation_transport_client.send_transaction - ) - mock_send_transaction.return_value = {} - - sender = self.hs.get_federation_sender() - self.get_success( - sender.send_presence_to_destinations( - [UserPresenceState.default("@user:test")], - ["server"], - ) - ) - - self.pump() - - # expect a call to send_transaction - mock_send_transaction.assert_awaited_once() - - json_cb = mock_send_transaction.call_args[0][1] - data = json_cb() - self.assertEqual( - data["edus"], - [ - { - "edu_type": EduTypes.PRESENCE, - "content": { - "push": [ - { - "presence": "offline", - "user_id": "@user:test", - } - ] - }, - } - ], - ) - - def test_presence_batched(self) -> None: - """Test that sending lots of presence updates to a 
destination are - batched, rather than having them all sent in one EDU.""" - - mock_send_transaction: AsyncMock = ( - self.federation_transport_client.send_transaction - ) - mock_send_transaction.return_value = {} - - sender = self.hs.get_federation_sender() - - # We now send lots of presence updates to force the federation sender to - # batch the mup. - number_presence_updates_to_send = MAX_PRESENCE_STATES_PER_EDU * 2 - self.get_success( - sender.send_presence_to_destinations( - [ - UserPresenceState.default(f"@user{i}:test") - for i in range(number_presence_updates_to_send) - ], - ["server"], - ) - ) - - self.pump() - - # We should have seen at least one transcation be sent by now. - mock_send_transaction.assert_called() - - # We don't want to specify exactly how the presence EDUs get sent out, - # could be one per transaction or multiple per transaction. We just want - # to assert that a) each presence EDU has bounded number of updates, and - # b) that all updates get sent out. - presence_edus = [] - for transaction_call in mock_send_transaction.call_args_list: - json_cb = transaction_call[0][1] - data = json_cb() - - for edu in data["edus"]: - self.assertEqual(edu.get("edu_type"), EduTypes.PRESENCE) - presence_edus.append(edu) - - # A set of all user presence we see, this should end up matching the - # number we sent out above. - seen_users: Set[str] = set() - - for edu in presence_edus: - presence_states = edu["content"]["push"] - - # This is where we actually check that the number of presence - # updates is bounded. - self.assertLessEqual(len(presence_states), MAX_PRESENCE_STATES_PER_EDU) - - seen_users.update(p["user_id"] for p in presence_states) - - self.assertEqual(len(seen_users), number_presence_updates_to_send) - - class FederationSenderDevicesTestCases(HomeserverTestCase): """ Test federation sending to update devices. diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index 88261450b1..8c9ef766d1 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019 Matrix.org Federation C.I.C # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -67,23 +66,6 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON") - def test_failed_edu_causes_500(self) -> None: - """If the EDU handler fails, /send should return a 500.""" - - async def failing_handler(_origin: str, _content: JsonDict) -> None: - raise Exception("bleh") - - self.hs.get_federation_registry().register_edu_handler( - "FAIL_EDU_TYPE", failing_handler - ) - - channel = self.make_signed_federation_request( - "PUT", - "/_matrix/federation/v1/send/txn", - {"edus": [{"edu_type": "FAIL_EDU_TYPE", "content": {}}]}, - ) - self.assertEqual(500, channel.code, channel.result) - class ServerACLsTestCase(unittest.TestCase): def test_blocked_server(self) -> None: diff --git a/tests/federation/transport/server/__init__.py b/tests/federation/transport/server/__init__.py
index dab387a504..3d833a2e44 100644 --- a/tests/federation/transport/server/__init__.py +++ b/tests/federation/transport/server/__init__.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/federation/transport/server/test__base.py b/tests/federation/transport/server/test__base.py
index 0e3b41ec4d..595349e70c 100644 --- a/tests/federation/transport/server/test__base.py +++ b/tests/federation/transport/server/test__base.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -147,10 +146,3 @@ class BaseFederationAuthorizationTests(unittest.TestCase): ), ("foo", "ed25519:1", "sig", "bar"), ) - # test that "optional whitespace(s)" (space and tabulation) are allowed between comma-separated auth-param components - self.assertEqual( - _parse_auth_header( - b'X-Matrix origin=foo , key="ed25519:1", sig="sig", destination="bar", extra_field=ignored' - ), - ("foo", "ed25519:1", "sig", "bar"), - ) diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py
index 3d882f99f2..17e8ddb61e 100644 --- a/tests/federation/transport/test_client.py +++ b/tests/federation/transport/test_client.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py
index 166a01c1a2..2f9aefd2b6 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 Matrix.org Federation C.I.C # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py
index 0237369998..e50f4bb4a3 100644 --- a/tests/federation/transport/test_server.py +++ b/tests/federation/transport/test_server.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -59,14 +58,7 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase): "/_matrix/federation/v1/send/txn_id_1234/", content={ "edus": [ - { - "edu_type": EduTypes.DEVICE_LIST_UPDATE, - "content": { - "device_id": "QBUAZIFURK", - "stream_id": 0, - "user_id": "@user:id", - }, - }, + {"edu_type": EduTypes.DEVICE_LIST_UPDATE, "content": {"foo": "bar"}} ], "pdus": [], }, diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index 9ff853a83d..09793c56a9 100644 --- a/tests/handlers/test_admin.py +++ b/tests/handlers/test_admin.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 1eec0d43b7..2bce5afbd4 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2015-2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index c417431e85..af6e5a8a77 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2015, 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index f41f7d36ad..36acbd790d 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py
index d7b54383db..1adc6d8854 100644 --- a/tests/handlers/test_deactivate_account.py +++ b/tests/handlers/test_deactivate_account.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -21,13 +20,12 @@ from twisted.test.proto_helpers import MemoryReactor -from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules, Membership +from synapse.api.constants import AccountDataTypes from synapse.push.rulekinds import PRIORITY_CLASS_MAP from synapse.rest import admin -from synapse.rest.client import account, login, room +from synapse.rest.client import account, login from synapse.server import HomeServer from synapse.synapse_rust.push import PushRule -from synapse.types import UserID, create_requester from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -38,7 +36,6 @@ class DeactivateAccountTestCase(HomeserverTestCase): login.register_servlets, admin.register_servlets, account.register_servlets, - room.register_servlets, ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -46,7 +43,6 @@ class DeactivateAccountTestCase(HomeserverTestCase): self.user = self.register_user("user", "pass") self.token = self.login("user", "pass") - self.handler = self.hs.get_room_member_handler() def _deactivate_my_account(self) -> None: """ @@ -344,142 +340,3 @@ class DeactivateAccountTestCase(HomeserverTestCase): self.assertEqual(req.code, 401, req) self.assertEqual(req.json_body["flows"], [{"stages": ["m.login.password"]}]) - - def test_deactivate_account_rejects_invites(self) -> None: - """ - Tests that deactivating an account rejects its invite memberships - """ - # Create another user and room just for the invitation - another_user = self.register_user("another_user", "pass") - token = self.login("another_user", "pass") - room_id = self.helper.create_room_as(another_user, is_public=False, tok=token) - - # Invite user to the created 
room - invite_event, _ = self.get_success( - self.handler.update_membership( - requester=create_requester(another_user), - target=UserID.from_string(self.user), - room_id=room_id, - action=Membership.INVITE, - ) - ) - - # Check that the invite exists - invite = self.get_success( - self._store.get_invited_rooms_for_local_user(self.user) - ) - self.assertEqual(invite[0].event_id, invite_event) - - # Deactivate the user - self._deactivate_my_account() - - # Check that the deactivated user has no invites in the room - after_deactivate_invite = self.get_success( - self._store.get_invited_rooms_for_local_user(self.user) - ) - self.assertEqual(len(after_deactivate_invite), 0) - - def test_deactivate_account_rejects_knocks(self) -> None: - """ - Tests that deactivating an account rejects its knock memberships - """ - # Create another user and room just for the invitation - another_user = self.register_user("another_user", "pass") - token = self.login("another_user", "pass") - room_id = self.helper.create_room_as( - another_user, - is_public=False, - tok=token, - ) - - # Allow room to be knocked at - self.helper.send_state( - room_id, - EventTypes.JoinRules, - {"join_rule": JoinRules.KNOCK}, - tok=token, - ) - - # Knock user at the created room - knock_event, _ = self.get_success( - self.handler.update_membership( - requester=create_requester(self.user), - target=UserID.from_string(self.user), - room_id=room_id, - action=Membership.KNOCK, - ) - ) - - # Check that the knock exists - knocks = self.get_success( - self._store.get_knocked_at_rooms_for_local_user(self.user) - ) - self.assertEqual(knocks[0].event_id, knock_event) - - # Deactivate the user - self._deactivate_my_account() - - # Check that the deactivated user has no knocks - after_deactivate_knocks = self.get_success( - self._store.get_knocked_at_rooms_for_local_user(self.user) - ) - self.assertEqual(len(after_deactivate_knocks), 0) - - def test_membership_is_redacted_upon_deactivation(self) -> None: - """ - Tests 
that room membership events are redacted if erasure is requested. - """ - # Create a room - room_id = self.helper.create_room_as( - self.user, - is_public=True, - tok=self.token, - ) - - # Change the displayname - membership_event, _ = self.get_success( - self.handler.update_membership( - requester=create_requester(self.user), - target=UserID.from_string(self.user), - room_id=room_id, - action=Membership.JOIN, - content={"displayname": "Hello World!"}, - ) - ) - - # Deactivate the account - self._deactivate_my_account() - - # Get the all membership event IDs - membership_event_ids = self.get_success( - self._store.get_membership_event_ids_for_user(self.user, room_id=room_id) - ) - - # Get the events incl. JSON - events = self.get_success(self._store.get_events_as_list(membership_event_ids)) - - # Validate that there is no displayname in any of the events - for event in events: - self.assertTrue("displayname" not in event.content) - - def test_rooms_forgotten_upon_deactivation(self) -> None: - """ - Tests that the user 'forgets' the rooms they left upon deactivation. - """ - # Create a room - room_id = self.helper.create_room_as( - self.user, - is_public=True, - tok=self.token, - ) - - # Deactivate the account - self._deactivate_my_account() - - # Get all of the user's forgotten rooms - forgotten_rooms = self.get_success( - self._store.get_forgotten_rooms_for_user(self.user) - ) - - # Validate that the created room is forgotten - self.assertTrue(room_id in forgotten_rooms) diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 080e6a7028..4369a50281 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py
@@ -1,8 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. -# Copyright 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 4a3e36ffde..03c360cca6 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py
@@ -1,8 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 Matrix.org Foundation C.I.C. -# Copyright 2014-2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -498,27 +496,19 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): self.denied_user_id = self.register_user("denied", "pass") self.denied_access_token = self.login("denied", "pass") - self.store = hs.get_datastores().main - def test_denied_without_publication_permission(self) -> None: """ - Try to create a room, register a valid alias for it, and publish it, + Try to create a room, register an alias for it, and publish it, as a user without permission to publish rooms. - The room should be created but not published. + (This is used as both a standalone test & as a helper function.) """ - room_id = self.helper.create_room_as( + self.helper.create_room_as( self.denied_user_id, tok=self.denied_access_token, extra_content=self.data, is_public=True, - expect_code=200, + expect_code=403, ) - res = self.get_success(self.store.get_room(room_id)) - assert res is not None - is_public, _ = res - - # room creation completes but room is not published to directory - self.assertEqual(is_public, False) def test_allowed_when_creating_private_room(self) -> None: """ @@ -536,8 +526,9 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): def test_allowed_with_publication_permission(self) -> None: """ - Try to create a room, register a valid alias for it, and publish it, + Try to create a room, register an alias for it, and publish it, as a user WITH permission to publish rooms. + (This is used as both a standalone test & as a helper function.) 
""" self.helper.create_room_as( self.allowed_user_id, @@ -549,26 +540,38 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): def test_denied_publication_with_invalid_alias(self) -> None: """ - Try to create a room, register an invalid alias for it, and publish it, + Try to create a room, register an alias for it, and publish it, as a user WITH permission to publish rooms. """ - room_id = self.helper.create_room_as( + self.helper.create_room_as( self.allowed_user_id, tok=self.allowed_access_token, extra_content={"room_alias_name": "foo"}, is_public=True, - expect_code=200, + expect_code=403, ) - # the room is created with the requested alias, but the room is not published - res = self.get_success(self.store.get_room(room_id)) - assert res is not None - is_public, _ = res + def test_can_create_as_private_room_after_rejection(self) -> None: + """ + After failing to publish a room with an alias as a user without publish permission, + retry as the same user, but without publishing the room. - self.assertFalse(is_public) + This should pass, but used to fail because the alias was registered by the first + request, even though the room creation was denied. + """ + self.test_denied_without_publication_permission() + self.test_allowed_when_creating_private_room() + + def test_can_create_with_permission_after_rejection(self) -> None: + """ + After failing to publish a room with an alias as a user without publish permission, + retry as someone with permission, using the same alias. - aliases = self.get_success(self.store.get_aliases_for_room(room_id)) - self.assertEqual(aliases[0], "#foo:test") + This also used to fail because of the alias having been registered by the first + request, leaving it unavailable for any other user's new rooms. 
+ """ + self.test_denied_without_publication_permission() + self.test_allowed_with_publication_permission() class TestRoomListSearchDisabled(unittest.HomeserverTestCase): diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 8a3dfdcf75..aa375fa218 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py
@@ -1,8 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019 The Matrix.org Foundation C.I.C. -# Copyright 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -43,7 +41,9 @@ from tests.unittest import override_config class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.appservice_api = mock.AsyncMock() - return self.setup_test_homeserver(application_service_api=self.appservice_api) + return self.setup_test_homeserver( + federation_client=mock.Mock(), application_service_api=self.appservice_api + ) def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = hs.get_e2e_keys_handler() @@ -1099,56 +1099,6 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): }, ) - def test_has_different_keys(self) -> None: - """check that has_different_keys returns True when the keys provided are different to what - is in the database.""" - local_user = "@boris:" + self.hs.hostname - keys1 = { - "master_key": { - # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 - "user_id": local_user, - "usage": ["master"], - "keys": { - "ed25519:nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk": "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk" - }, - } - } - self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1)) - is_different = self.get_success( - self.handler.has_different_keys( - local_user, - { - "master_key": keys1["master_key"], - }, - ) - ) - self.assertEqual(is_different, False) - # change the usage => different keys - keys1["master_key"]["usage"] = ["develop"] - is_different = self.get_success( - self.handler.has_different_keys( - local_user, - { - "master_key": keys1["master_key"], - }, - ) - ) - self.assertEqual(is_different, True) - keys1["master_key"]["usage"] = ["master"] # reset - # change 
the key => different keys - keys1["master_key"]["keys"] = { - "ed25519:nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unIc0rncs": "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unIc0rncs" - } - is_different = self.get_success( - self.handler.has_different_keys( - local_user, - { - "master_key": keys1["master_key"], - }, - ) - ) - self.assertEqual(is_different, True) - def test_query_devices_remote_sync(self) -> None: """Tests that querying keys for a remote user that we share a room with, but haven't yet fetched the keys for, returns the cross signing keys @@ -1222,61 +1172,6 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): }, ) - def test_query_devices_remote_down(self) -> None: - """Tests that querying keys for a remote user on an unreachable server returns - results in the "failures" property - """ - - remote_user_id = "@test:other" - local_user_id = "@test:test" - - # The backoff code treats time zero as special - self.reactor.advance(5) - - self.hs.get_federation_http_client().agent.request = mock.AsyncMock( # type: ignore[method-assign] - side_effect=Exception("boop") - ) - - e2e_handler = self.hs.get_e2e_keys_handler() - - query_result = self.get_success( - e2e_handler.query_devices( - { - "device_keys": {remote_user_id: []}, - }, - timeout=10, - from_user_id=local_user_id, - from_device_id="some_device_id", - ) - ) - - self.assertEqual( - query_result["failures"], - { - "other": { - "message": "Failed to send request: Exception: boop", - "status": 503, - } - }, - ) - - # Do it again: we should hit the backoff - query_result = self.get_success( - e2e_handler.query_devices( - { - "device_keys": {remote_user_id: []}, - }, - timeout=10, - from_user_id=local_user_id, - from_device_id="some_device_id", - ) - ) - - self.assertEqual( - query_result["failures"], - {"other": {"message": "Not ready for retry", "status": 503}}, - ) - @parameterized.expand( [ # The remote homeserver's response indicates that this user has 0/1/2 devices. 
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 3ec46402b7..3217639176 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py
@@ -1,8 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019 Matrix.org Foundation C.I.C. -# Copyright 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 3fe5b0a1b4..8e2ad42ab8 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -483,7 +482,6 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): event.room_version, ), exc=LimitExceededError, - by=0.5, ) def _build_and_send_join_event( diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index 1b83aea579..938c5dec60 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 76ab83d1f7..802f96ed48 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -24,7 +23,6 @@ from typing import Tuple from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes -from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.events.snapshot import EventContext, UnpersistedEventContextBase from synapse.rest import admin @@ -52,15 +50,11 @@ class EventCreationTestCase(unittest.HomeserverTestCase): persistence = self.hs.get_storage_controllers().persistence assert persistence is not None self._persist_event_storage_controller = persistence - self.store = self.hs.get_datastores().main self.user_id = self.register_user("tester", "foobar") device_id = "dev-1" access_token = self.login("tester", "foobar", device_id=device_id) self.room_id = self.helper.create_room_as(self.user_id, tok=access_token) - self.private_room_id = self.helper.create_room_as( - self.user_id, tok=access_token, extra_content={"preset": "private_chat"} - ) self.requester = create_requester(self.user_id, device_id=device_id) @@ -290,41 +284,6 @@ class EventCreationTestCase(unittest.HomeserverTestCase): AssertionError, ) - def test_call_invite_event_creation_fails_in_public_room(self) -> None: - # get prev_events for room - prev_events = self.get_success( - self.store.get_prev_events_for_room(self.room_id) - ) - - # the invite in a public room should fail - self.get_failure( - self.handler.create_event( - self.requester, - { - "type": EventTypes.CallInvite, - "room_id": self.room_id, - "sender": self.requester.user.to_string(), - }, - prev_event_ids=prev_events, - auth_event_ids=prev_events, - ), - SynapseError, - ) - - # but a call invite in a private room should succeed - self.get_success( - self.handler.create_event( - 
self.requester, - { - "type": EventTypes.CallInvite, - "room_id": self.private_room_id, - "sender": self.requester.user.to_string(), - }, - prev_event_ids=prev_events, - auth_event_ids=prev_events, - ) - ) - class ServerAclValidationTestCase(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py
index 036c539db2..b9761d806d 100644 --- a/tests/handlers/test_oauth_delegation.py +++ b/tests/handlers/test_oauth_delegation.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -541,8 +540,6 @@ class MSC3861OAuthDelegation(HomeserverTestCase): self.assertEqual(channel.code, 200, channel.json_body) - # Try uploading *different* keys; it should cause a 501 error. - keys_upload_body = self.make_device_keys(USER_ID, DEVICE) channel = self.make_request( "POST", "/_matrix/client/v3/keys/device_signing/upload", diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index a81501979d..df30a80734 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 Quentin Gliech # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index ed203eb299..b3248fa491 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index cc630d606c..c68e8c6631 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index cb1c6fbb80..d70026c31e 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2014-2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py
index 7c5bec2b76..3c14dfa0a8 100644 --- a/tests/handlers/test_receipts.py +++ b/tests/handlers/test_receipts.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 Šimon Brandner <simon.bra.ag@gmail.com> # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 92487692db..36255270ed 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2015, 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py
index 213a66ed1a..3e28117e2c 100644 --- a/tests/handlers/test_room_member.py +++ b/tests/handlers/test_room_member.py
@@ -70,7 +70,6 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase): action=Membership.JOIN, ), LimitExceededError, - by=0.5, ) @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}}) @@ -207,7 +206,6 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase): remote_room_hosts=[self.OTHER_SERVER_NAME], ), LimitExceededError, - by=0.5, ) # TODO: test that remote joins to a room are rate limited. @@ -275,7 +273,6 @@ class TestReplicatedJoinsLimitedByPerRoomRateLimiter(BaseMultiWorkerStreamTestCa action=Membership.JOIN, ), LimitExceededError, - by=0.5, ) # Try to join as Chris on the original worker. Should get denied because Alice @@ -288,7 +285,6 @@ class TestReplicatedJoinsLimitedByPerRoomRateLimiter(BaseMultiWorkerStreamTestCa action=Membership.JOIN, ), LimitExceededError, - by=0.5, ) @@ -407,24 +403,3 @@ class RoomMemberMasterHandlerTestCase(HomeserverTestCase): self.assertFalse( self.get_success(self.store.did_forget(self.alice, self.room_id)) ) - - def test_deduplicate_joins(self) -> None: - """ - Test that calling /join multiple times does not store a new state group. - """ - - self.helper.join(self.room_id, user=self.bob, tok=self.bob_token) - - sql = "SELECT COUNT(*) FROM state_groups WHERE room_id = ?" - rows = self.get_success( - self.store.db_pool.execute("test_deduplicate_joins", sql, self.room_id) - ) - initial_count = rows[0][0] - - self.helper.join(self.room_id, user=self.bob, tok=self.bob_token) - rows = self.get_success( - self.store.db_pool.execute("test_deduplicate_joins", sql, self.room_id) - ) - new_count = rows[0][0] - - self.assertEqual(initial_count, new_count) diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index 244a4e7689..929772c412 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 6ab8fda6e7..9d7cd4ee6f 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py
index cedcea27d9..73ea7fc346 100644 --- a/tests/handlers/test_send_email.py +++ b/tests/handlers/test_send_email.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/handlers/test_sliding_sync.py b/tests/handlers/test_sliding_sync.py deleted file mode 100644
index 96da47f3b9..0000000000 --- a/tests/handlers/test_sliding_sync.py +++ /dev/null
@@ -1,4567 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2024 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -# Originally licensed under the Apache License, Version 2.0: -# <http://www.apache.org/licenses/LICENSE-2.0>. -# -# [This file includes modifications made by New Vector Limited] -# -# -import logging -from copy import deepcopy -from typing import Dict, List, Optional -from unittest.mock import patch - -from parameterized import parameterized - -from twisted.test.proto_helpers import MemoryReactor - -from synapse.api.constants import ( - AccountDataTypes, - EventContentFields, - EventTypes, - JoinRules, - Membership, - RoomTypes, -) -from synapse.api.room_versions import RoomVersions -from synapse.events import StrippedStateEvent, make_event_from_dict -from synapse.events.snapshot import EventContext -from synapse.handlers.sliding_sync import ( - RoomSyncConfig, - StateValues, - _RoomMembershipForUser, -) -from synapse.rest import admin -from synapse.rest.client import knock, login, room -from synapse.server import HomeServer -from synapse.storage.util.id_generators import MultiWriterIdGenerator -from synapse.types import JsonDict, StreamToken, UserID -from synapse.types.handlers import SlidingSyncConfig -from synapse.util import Clock - -from tests.replication._base import BaseMultiWorkerStreamTestCase -from tests.unittest import HomeserverTestCase, TestCase - -logger = logging.getLogger(__name__) - - -class RoomSyncConfigTestCase(TestCase): - def _assert_room_config_equal( - self, - actual: RoomSyncConfig, - expected: RoomSyncConfig, - 
message_prefix: Optional[str] = None, - ) -> None: - self.assertEqual(actual.timeline_limit, expected.timeline_limit, message_prefix) - - # `self.assertEqual(...)` works fine to catch differences but the output is - # almost impossible to read because of the way it truncates the output and the - # order doesn't actually matter. - self.assertCountEqual( - actual.required_state_map, expected.required_state_map, message_prefix - ) - for event_type, expected_state_keys in expected.required_state_map.items(): - self.assertCountEqual( - actual.required_state_map[event_type], - expected_state_keys, - f"{message_prefix}: Mismatch for {event_type}", - ) - - @parameterized.expand( - [ - ( - "from_list_config", - """ - Test that we can convert a `SlidingSyncConfig.SlidingSyncList` to a - `RoomSyncConfig`. - """, - # Input - SlidingSyncConfig.SlidingSyncList( - timeline_limit=10, - required_state=[ - (EventTypes.Name, ""), - (EventTypes.Member, "@foo"), - (EventTypes.Member, "@bar"), - (EventTypes.Member, "@baz"), - (EventTypes.CanonicalAlias, ""), - ], - ), - # Expected - RoomSyncConfig( - timeline_limit=10, - required_state_map={ - EventTypes.Name: {""}, - EventTypes.Member: { - "@foo", - "@bar", - "@baz", - }, - EventTypes.CanonicalAlias: {""}, - }, - ), - ), - ( - "from_room_subscription", - """ - Test that we can convert a `SlidingSyncConfig.RoomSubscription` to a - `RoomSyncConfig`. - """, - # Input - SlidingSyncConfig.RoomSubscription( - timeline_limit=10, - required_state=[ - (EventTypes.Name, ""), - (EventTypes.Member, "@foo"), - (EventTypes.Member, "@bar"), - (EventTypes.Member, "@baz"), - (EventTypes.CanonicalAlias, ""), - ], - ), - # Expected - RoomSyncConfig( - timeline_limit=10, - required_state_map={ - EventTypes.Name: {""}, - EventTypes.Member: { - "@foo", - "@bar", - "@baz", - }, - EventTypes.CanonicalAlias: {""}, - }, - ), - ), - ( - "wildcard", - """ - Test that a wildcard (*) for both the `event_type` and `state_key` will override - all other values. 
- - Note: MSC3575 describes different behavior to how we're handling things here but - since it's not wrong to return more state than requested (`required_state` is - just the minimum requested), it doesn't matter if we include things that the - client wanted excluded. This complexity is also under scrutiny, see - https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1185109050 - - > One unique exception is when you request all state events via ["*", "*"]. When used, - > all state events are returned by default, and additional entries FILTER OUT the returned set - > of state events. These additional entries cannot use '*' themselves. - > For example, ["*", "*"], ["m.room.member", "@alice:example.com"] will _exclude_ every m.room.member - > event _except_ for @alice:example.com, and include every other state event. - > In addition, ["*", "*"], ["m.space.child", "*"] is an error, the m.space.child filter is not - > required as it would have been returned anyway. - > - > -- MSC3575 (https://github.com/matrix-org/matrix-spec-proposals/pull/3575) - """, - # Input - SlidingSyncConfig.SlidingSyncList( - timeline_limit=10, - required_state=[ - (EventTypes.Name, ""), - (StateValues.WILDCARD, StateValues.WILDCARD), - (EventTypes.Member, "@foo"), - (EventTypes.CanonicalAlias, ""), - ], - ), - # Expected - RoomSyncConfig( - timeline_limit=10, - required_state_map={ - StateValues.WILDCARD: {StateValues.WILDCARD}, - }, - ), - ), - ( - "wildcard_type", - """ - Test that a wildcard (*) as a `event_type` will override all other values for the - same `state_key`. 
- """, - # Input - SlidingSyncConfig.SlidingSyncList( - timeline_limit=10, - required_state=[ - (EventTypes.Name, ""), - (StateValues.WILDCARD, ""), - (EventTypes.Member, "@foo"), - (EventTypes.CanonicalAlias, ""), - ], - ), - # Expected - RoomSyncConfig( - timeline_limit=10, - required_state_map={ - StateValues.WILDCARD: {""}, - EventTypes.Member: {"@foo"}, - }, - ), - ), - ( - "multiple_wildcard_type", - """ - Test that multiple wildcard (*) as a `event_type` will override all other values - for the same `state_key`. - """, - # Input - SlidingSyncConfig.SlidingSyncList( - timeline_limit=10, - required_state=[ - (EventTypes.Name, ""), - (StateValues.WILDCARD, ""), - (EventTypes.Member, "@foo"), - (StateValues.WILDCARD, "@foo"), - ("org.matrix.personal_count", "@foo"), - (EventTypes.Member, "@bar"), - (EventTypes.CanonicalAlias, ""), - ], - ), - # Expected - RoomSyncConfig( - timeline_limit=10, - required_state_map={ - StateValues.WILDCARD: { - "", - "@foo", - }, - EventTypes.Member: {"@bar"}, - }, - ), - ), - ( - "wildcard_state_key", - """ - Test that a wildcard (*) as a `state_key` will override all other values for the - same `event_type`. - """, - # Input - SlidingSyncConfig.SlidingSyncList( - timeline_limit=10, - required_state=[ - (EventTypes.Name, ""), - (EventTypes.Member, "@foo"), - (EventTypes.Member, StateValues.WILDCARD), - (EventTypes.Member, "@bar"), - (EventTypes.Member, StateValues.LAZY), - (EventTypes.Member, "@baz"), - (EventTypes.CanonicalAlias, ""), - ], - ), - # Expected - RoomSyncConfig( - timeline_limit=10, - required_state_map={ - EventTypes.Name: {""}, - EventTypes.Member: { - StateValues.WILDCARD, - }, - EventTypes.CanonicalAlias: {""}, - }, - ), - ), - ( - "wildcard_merge", - """ - Test that a wildcard (*) entries for the `event_type` and another one for - `state_key` will play together. 
- """, - # Input - SlidingSyncConfig.SlidingSyncList( - timeline_limit=10, - required_state=[ - (EventTypes.Name, ""), - (StateValues.WILDCARD, ""), - (EventTypes.Member, "@foo"), - (EventTypes.Member, StateValues.WILDCARD), - (EventTypes.Member, "@bar"), - (EventTypes.CanonicalAlias, ""), - ], - ), - # Expected - RoomSyncConfig( - timeline_limit=10, - required_state_map={ - StateValues.WILDCARD: {""}, - EventTypes.Member: {StateValues.WILDCARD}, - }, - ), - ), - ( - "wildcard_merge2", - """ - Test that an all wildcard ("*", "*") entry will override any other - values (including other wildcards). - """, - # Input - SlidingSyncConfig.SlidingSyncList( - timeline_limit=10, - required_state=[ - (EventTypes.Name, ""), - (StateValues.WILDCARD, ""), - (EventTypes.Member, StateValues.WILDCARD), - (EventTypes.Member, "@foo"), - # One of these should take precedence over everything else - (StateValues.WILDCARD, StateValues.WILDCARD), - (StateValues.WILDCARD, StateValues.WILDCARD), - (EventTypes.CanonicalAlias, ""), - ], - ), - # Expected - RoomSyncConfig( - timeline_limit=10, - required_state_map={ - StateValues.WILDCARD: {StateValues.WILDCARD}, - }, - ), - ), - ( - "lazy_members", - """ - `$LAZY` room members should just be another additional key next to other - explicit keys. We will unroll the special `$LAZY` meaning later. 
- """, - # Input - SlidingSyncConfig.SlidingSyncList( - timeline_limit=10, - required_state=[ - (EventTypes.Name, ""), - (EventTypes.Member, "@foo"), - (EventTypes.Member, "@bar"), - (EventTypes.Member, StateValues.LAZY), - (EventTypes.Member, "@baz"), - (EventTypes.CanonicalAlias, ""), - ], - ), - # Expected - RoomSyncConfig( - timeline_limit=10, - required_state_map={ - EventTypes.Name: {""}, - EventTypes.Member: { - "@foo", - "@bar", - StateValues.LAZY, - "@baz", - }, - EventTypes.CanonicalAlias: {""}, - }, - ), - ), - ] - ) - def test_from_room_config( - self, - _test_label: str, - _test_description: str, - room_params: SlidingSyncConfig.CommonRoomParameters, - expected_room_sync_config: RoomSyncConfig, - ) -> None: - """ - Test `RoomSyncConfig.from_room_config(room_params)` will result in the `expected_room_sync_config`. - """ - room_sync_config = RoomSyncConfig.from_room_config(room_params) - - self._assert_room_config_equal( - room_sync_config, - expected_room_sync_config, - ) - - @parameterized.expand( - [ - ( - "no_direct_overlap", - # A - RoomSyncConfig( - timeline_limit=9, - required_state_map={ - EventTypes.Name: {""}, - EventTypes.Member: { - "@foo", - "@bar", - }, - }, - ), - # B - RoomSyncConfig( - timeline_limit=10, - required_state_map={ - EventTypes.Member: { - StateValues.LAZY, - "@baz", - }, - EventTypes.CanonicalAlias: {""}, - }, - ), - # Expected - RoomSyncConfig( - timeline_limit=10, - required_state_map={ - EventTypes.Name: {""}, - EventTypes.Member: { - "@foo", - "@bar", - StateValues.LAZY, - "@baz", - }, - EventTypes.CanonicalAlias: {""}, - }, - ), - ), - ( - "wildcard_overlap", - # A - RoomSyncConfig( - timeline_limit=10, - required_state_map={ - StateValues.WILDCARD: {StateValues.WILDCARD}, - }, - ), - # B - RoomSyncConfig( - timeline_limit=9, - required_state_map={ - EventTypes.Dummy: {StateValues.WILDCARD}, - StateValues.WILDCARD: {"@bar"}, - EventTypes.Member: {"@foo"}, - }, - ), - # Expected - RoomSyncConfig( - timeline_limit=10, - 
required_state_map={ - StateValues.WILDCARD: {StateValues.WILDCARD}, - }, - ), - ), - ( - "state_type_wildcard_overlap", - # A - RoomSyncConfig( - timeline_limit=10, - required_state_map={ - EventTypes.Dummy: {"dummy"}, - StateValues.WILDCARD: { - "", - "@foo", - }, - EventTypes.Member: {"@bar"}, - }, - ), - # B - RoomSyncConfig( - timeline_limit=9, - required_state_map={ - EventTypes.Dummy: {"dummy2"}, - StateValues.WILDCARD: { - "", - "@bar", - }, - EventTypes.Member: {"@foo"}, - }, - ), - # Expected - RoomSyncConfig( - timeline_limit=10, - required_state_map={ - EventTypes.Dummy: { - "dummy", - "dummy2", - }, - StateValues.WILDCARD: { - "", - "@foo", - "@bar", - }, - }, - ), - ), - ( - "state_key_wildcard_overlap", - # A - RoomSyncConfig( - timeline_limit=10, - required_state_map={ - EventTypes.Dummy: {"dummy"}, - EventTypes.Member: {StateValues.WILDCARD}, - "org.matrix.flowers": {StateValues.WILDCARD}, - }, - ), - # B - RoomSyncConfig( - timeline_limit=9, - required_state_map={ - EventTypes.Dummy: {StateValues.WILDCARD}, - EventTypes.Member: {StateValues.WILDCARD}, - "org.matrix.flowers": {"tulips"}, - }, - ), - # Expected - RoomSyncConfig( - timeline_limit=10, - required_state_map={ - EventTypes.Dummy: {StateValues.WILDCARD}, - EventTypes.Member: {StateValues.WILDCARD}, - "org.matrix.flowers": {StateValues.WILDCARD}, - }, - ), - ), - ( - "state_type_and_state_key_wildcard_merge", - # A - RoomSyncConfig( - timeline_limit=10, - required_state_map={ - EventTypes.Dummy: {"dummy"}, - StateValues.WILDCARD: { - "", - "@foo", - }, - EventTypes.Member: {"@bar"}, - }, - ), - # B - RoomSyncConfig( - timeline_limit=9, - required_state_map={ - EventTypes.Dummy: {"dummy2"}, - StateValues.WILDCARD: {""}, - EventTypes.Member: {StateValues.WILDCARD}, - }, - ), - # Expected - RoomSyncConfig( - timeline_limit=10, - required_state_map={ - EventTypes.Dummy: { - "dummy", - "dummy2", - }, - StateValues.WILDCARD: { - "", - "@foo", - }, - EventTypes.Member: {StateValues.WILDCARD}, - 
}, - ), - ), - ] - ) - def test_combine_room_sync_config( - self, - _test_label: str, - a: RoomSyncConfig, - b: RoomSyncConfig, - expected: RoomSyncConfig, - ) -> None: - """ - Combine A into B and B into A to make sure we get the same result. - """ - # Since we're mutating these in place, make a copy for each of our trials - room_sync_config_a = deepcopy(a) - room_sync_config_b = deepcopy(b) - - # Combine B into A - room_sync_config_a.combine_room_sync_config(room_sync_config_b) - - self._assert_room_config_equal(room_sync_config_a, expected, "B into A") - - # Since we're mutating these in place, make a copy for each of our trials - room_sync_config_a = deepcopy(a) - room_sync_config_b = deepcopy(b) - - # Combine A into B - room_sync_config_b.combine_room_sync_config(room_sync_config_a) - - self._assert_room_config_equal(room_sync_config_b, expected, "A into B") - - -class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): - """ - Tests Sliding Sync handler `get_room_membership_for_user_at_to_token()` to make sure it returns - the correct list of rooms IDs. 
- """ - - servlets = [ - admin.register_servlets, - knock.register_servlets, - login.register_servlets, - room.register_servlets, - ] - - def default_config(self) -> JsonDict: - config = super().default_config() - # Enable sliding sync - config["experimental_features"] = {"msc3575_enabled": True} - return config - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.sliding_sync_handler = self.hs.get_sliding_sync_handler() - self.store = self.hs.get_datastores().main - self.event_sources = hs.get_event_sources() - self.storage_controllers = hs.get_storage_controllers() - - def test_no_rooms(self) -> None: - """ - Test when the user has never joined any rooms before - """ - user1_id = self.register_user("user1", "pass") - # user1_tok = self.login(user1_id, "pass") - - now_token = self.event_sources.get_current_token() - - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=now_token, - to_token=now_token, - ) - ) - - self.assertEqual(room_id_results.keys(), set()) - - def test_get_newly_joined_room(self) -> None: - """ - Test that rooms that the user has newly_joined show up. newly_joined is when you - join after the `from_token` and <= `to_token`. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - before_room_token = self.event_sources.get_current_token() - - room_id = self.helper.create_room_as(user2_id, tok=user2_tok) - join_response = self.helper.join(room_id, user1_id, tok=user1_tok) - - after_room_token = self.event_sources.get_current_token() - - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=before_room_token, - to_token=after_room_token, - ) - ) - - self.assertEqual(room_id_results.keys(), {room_id}) - # It should be pointing to the join event (latest membership event in the - # from/to range) - self.assertEqual( - room_id_results[room_id].event_id, - join_response["event_id"], - ) - self.assertEqual(room_id_results[room_id].membership, Membership.JOIN) - # We should be considered `newly_joined` because we joined during the token - # range - self.assertEqual(room_id_results[room_id].newly_joined, True) - self.assertEqual(room_id_results[room_id].newly_left, False) - - def test_get_already_joined_room(self) -> None: - """ - Test that rooms that the user is already joined show up. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id = self.helper.create_room_as(user2_id, tok=user2_tok) - join_response = self.helper.join(room_id, user1_id, tok=user1_tok) - - after_room_token = self.event_sources.get_current_token() - - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=after_room_token, - to_token=after_room_token, - ) - ) - - self.assertEqual(room_id_results.keys(), {room_id}) - # It should be pointing to the join event (latest membership event in the - # from/to range) - self.assertEqual( - room_id_results[room_id].event_id, - join_response["event_id"], - ) - self.assertEqual(room_id_results[room_id].membership, Membership.JOIN) - # We should *NOT* be `newly_joined` because we joined before the token range - self.assertEqual(room_id_results[room_id].newly_joined, False) - self.assertEqual(room_id_results[room_id].newly_left, False) - - def test_get_invited_banned_knocked_room(self) -> None: - """ - Test that rooms that the user is invited to, banned from, and knocked on show - up. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - before_room_token = self.event_sources.get_current_token() - - # Setup the invited room (user2 invites user1 to the room) - invited_room_id = self.helper.create_room_as(user2_id, tok=user2_tok) - invite_response = self.helper.invite( - invited_room_id, targ=user1_id, tok=user2_tok - ) - - # Setup the ban room (user2 bans user1 from the room) - ban_room_id = self.helper.create_room_as( - user2_id, tok=user2_tok, is_public=True - ) - self.helper.join(ban_room_id, user1_id, tok=user1_tok) - ban_response = self.helper.ban( - ban_room_id, src=user2_id, targ=user1_id, tok=user2_tok - ) - - # Setup the knock room (user1 knocks on the room) - knock_room_id = self.helper.create_room_as( - user2_id, tok=user2_tok, room_version=RoomVersions.V7.identifier - ) - self.helper.send_state( - knock_room_id, - EventTypes.JoinRules, - {"join_rule": JoinRules.KNOCK}, - tok=user2_tok, - ) - # User1 knocks on the room - knock_channel = self.make_request( - "POST", - "/_matrix/client/r0/knock/%s" % (knock_room_id,), - b"{}", - user1_tok, - ) - self.assertEqual(knock_channel.code, 200, knock_channel.result) - knock_room_membership_state_event = self.get_success( - self.storage_controllers.state.get_current_state_event( - knock_room_id, EventTypes.Member, user1_id - ) - ) - assert knock_room_membership_state_event is not None - - after_room_token = self.event_sources.get_current_token() - - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=before_room_token, - to_token=after_room_token, - ) - ) - - # Ensure that the invited, ban, and knock rooms show up - self.assertEqual( - room_id_results.keys(), - { - invited_room_id, - ban_room_id, - knock_room_id, - }, - ) - # It should be pointing to the the 
respective membership event (latest - # membership event in the from/to range) - self.assertEqual( - room_id_results[invited_room_id].event_id, - invite_response["event_id"], - ) - self.assertEqual(room_id_results[invited_room_id].membership, Membership.INVITE) - self.assertEqual(room_id_results[invited_room_id].newly_joined, False) - self.assertEqual(room_id_results[invited_room_id].newly_left, False) - - self.assertEqual( - room_id_results[ban_room_id].event_id, - ban_response["event_id"], - ) - self.assertEqual(room_id_results[ban_room_id].membership, Membership.BAN) - self.assertEqual(room_id_results[ban_room_id].newly_joined, False) - self.assertEqual(room_id_results[ban_room_id].newly_left, False) - - self.assertEqual( - room_id_results[knock_room_id].event_id, - knock_room_membership_state_event.event_id, - ) - self.assertEqual(room_id_results[knock_room_id].membership, Membership.KNOCK) - self.assertEqual(room_id_results[knock_room_id].newly_joined, False) - self.assertEqual(room_id_results[knock_room_id].newly_left, False) - - def test_get_kicked_room(self) -> None: - """ - Test that a room that the user was kicked from still shows up. When the user - comes back to their client, they should see that they were kicked. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - # Setup the kick room (user2 kicks user1 from the room) - kick_room_id = self.helper.create_room_as( - user2_id, tok=user2_tok, is_public=True - ) - self.helper.join(kick_room_id, user1_id, tok=user1_tok) - # Kick user1 from the room - kick_response = self.helper.change_membership( - room=kick_room_id, - src=user2_id, - targ=user1_id, - tok=user2_tok, - membership=Membership.LEAVE, - extra_data={ - "reason": "Bad manners", - }, - ) - - after_kick_token = self.event_sources.get_current_token() - - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=after_kick_token, - to_token=after_kick_token, - ) - ) - - # The kicked room should show up - self.assertEqual(room_id_results.keys(), {kick_room_id}) - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[kick_room_id].event_id, - kick_response["event_id"], - ) - self.assertEqual(room_id_results[kick_room_id].membership, Membership.LEAVE) - self.assertNotEqual(room_id_results[kick_room_id].sender, user1_id) - # We should *NOT* be `newly_joined` because we were not joined at the the time - # of the `to_token`. - self.assertEqual(room_id_results[kick_room_id].newly_joined, False) - self.assertEqual(room_id_results[kick_room_id].newly_left, False) - - def test_forgotten_rooms(self) -> None: - """ - Forgotten rooms do not show up even if we forget after the from/to range. - - Ideally, we would be able to track when the `/forget` happens and apply it - accordingly in the token range but the forgotten flag is only an extra bool in - the `room_memberships` table. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - # Setup a normal room that we leave. This won't show up in the sync response - # because we left it before our token but is good to check anyway. - leave_room_id = self.helper.create_room_as( - user2_id, tok=user2_tok, is_public=True - ) - self.helper.join(leave_room_id, user1_id, tok=user1_tok) - self.helper.leave(leave_room_id, user1_id, tok=user1_tok) - - # Setup the ban room (user2 bans user1 from the room) - ban_room_id = self.helper.create_room_as( - user2_id, tok=user2_tok, is_public=True - ) - self.helper.join(ban_room_id, user1_id, tok=user1_tok) - self.helper.ban(ban_room_id, src=user2_id, targ=user1_id, tok=user2_tok) - - # Setup the kick room (user2 kicks user1 from the room) - kick_room_id = self.helper.create_room_as( - user2_id, tok=user2_tok, is_public=True - ) - self.helper.join(kick_room_id, user1_id, tok=user1_tok) - # Kick user1 from the room - self.helper.change_membership( - room=kick_room_id, - src=user2_id, - targ=user1_id, - tok=user2_tok, - membership=Membership.LEAVE, - extra_data={ - "reason": "Bad manners", - }, - ) - - before_room_forgets = self.event_sources.get_current_token() - - # Forget the room after we already have our tokens. 
This doesn't change - # the membership event itself but will mark it internally in Synapse - channel = self.make_request( - "POST", - f"/_matrix/client/r0/rooms/{leave_room_id}/forget", - content={}, - access_token=user1_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - channel = self.make_request( - "POST", - f"/_matrix/client/r0/rooms/{ban_room_id}/forget", - content={}, - access_token=user1_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - channel = self.make_request( - "POST", - f"/_matrix/client/r0/rooms/{kick_room_id}/forget", - content={}, - access_token=user1_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=before_room_forgets, - to_token=before_room_forgets, - ) - ) - - # We shouldn't see the room because it was forgotten - self.assertEqual(room_id_results.keys(), set()) - - def test_newly_left_rooms(self) -> None: - """ - Test that newly_left are marked properly - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - # Leave before we calculate the `from_token` - room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok) - leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok) - - after_room1_token = self.event_sources.get_current_token() - - # Leave during the from_token/to_token range (newly_left) - room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok) - leave_response2 = self.helper.leave(room_id2, user1_id, tok=user1_tok) - - after_room2_token = self.event_sources.get_current_token() - - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=after_room1_token, - to_token=after_room2_token, - ) - ) - - self.assertEqual(room_id_results.keys(), {room_id1, room_id2}) - - self.assertEqual( - 
room_id_results[room_id1].event_id, - leave_response1["event_id"], - ) - self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE) - # We should *NOT* be `newly_joined` or `newly_left` because that happened before - # the from/to range - self.assertEqual(room_id_results[room_id1].newly_joined, False) - self.assertEqual(room_id_results[room_id1].newly_left, False) - - self.assertEqual( - room_id_results[room_id2].event_id, - leave_response2["event_id"], - ) - self.assertEqual(room_id_results[room_id2].membership, Membership.LEAVE) - # We should *NOT* be `newly_joined` because we are instead `newly_left` - self.assertEqual(room_id_results[room_id2].newly_joined, False) - self.assertEqual(room_id_results[room_id2].newly_left, True) - - def test_no_joins_after_to_token(self) -> None: - """ - Rooms we join after the `to_token` should *not* show up. See condition "1b)" - comments in the `get_room_membership_for_user_at_to_token()` method. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - before_room1_token = self.event_sources.get_current_token() - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok) - - after_room1_token = self.event_sources.get_current_token() - - # Room join after our `to_token` shouldn't show up - room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id2, user1_id, tok=user1_tok) - - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=before_room1_token, - to_token=after_room1_token, - ) - ) - - self.assertEqual(room_id_results.keys(), {room_id1}) - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[room_id1].event_id, - 
join_response1["event_id"], - ) - self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) - # We should be `newly_joined` because we joined during the token range - self.assertEqual(room_id_results[room_id1].newly_joined, True) - self.assertEqual(room_id_results[room_id1].newly_left, False) - - def test_join_during_range_and_left_room_after_to_token(self) -> None: - """ - Room still shows up if we left the room but were joined during the - from_token/to_token. See condition "1a)" comments in the - `get_room_membership_for_user_at_to_token()` method. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - before_room1_token = self.event_sources.get_current_token() - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - join_response = self.helper.join(room_id1, user1_id, tok=user1_tok) - - after_room1_token = self.event_sources.get_current_token() - - # Leave the room after we already have our tokens - leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok) - - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=before_room1_token, - to_token=after_room1_token, - ) - ) - - # We should still see the room because we were joined during the - # from_token/to_token time period. 
- self.assertEqual(room_id_results.keys(), {room_id1}) - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[room_id1].event_id, - join_response["event_id"], - "Corresponding map to disambiguate the opaque event IDs: " - + str( - { - "join_response": join_response["event_id"], - "leave_response": leave_response["event_id"], - } - ), - ) - self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) - # We should be `newly_joined` because we joined during the token range - self.assertEqual(room_id_results[room_id1].newly_joined, True) - self.assertEqual(room_id_results[room_id1].newly_left, False) - - def test_join_before_range_and_left_room_after_to_token(self) -> None: - """ - Room still shows up if we left the room but were joined before the `from_token` - so it should show up. See condition "1a)" comments in the - `get_room_membership_for_user_at_to_token()` method. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - join_response = self.helper.join(room_id1, user1_id, tok=user1_tok) - - after_room1_token = self.event_sources.get_current_token() - - # Leave the room after we already have our tokens - leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok) - - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=after_room1_token, - to_token=after_room1_token, - ) - ) - - # We should still see the room because we were joined before the `from_token` - self.assertEqual(room_id_results.keys(), {room_id1}) - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[room_id1].event_id, - join_response["event_id"], - 
"Corresponding map to disambiguate the opaque event IDs: " - + str( - { - "join_response": join_response["event_id"], - "leave_response": leave_response["event_id"], - } - ), - ) - self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) - # We should *NOT* be `newly_joined` because we joined before the token range - self.assertEqual(room_id_results[room_id1].newly_joined, False) - self.assertEqual(room_id_results[room_id1].newly_left, False) - - def test_kicked_before_range_and_left_after_to_token(self) -> None: - """ - Room still shows up if we left the room but were kicked before the `from_token` - so it should show up. See condition "1a)" comments in the - `get_room_membership_for_user_at_to_token()` method. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - # Setup the kick room (user2 kicks user1 from the room) - kick_room_id = self.helper.create_room_as( - user2_id, tok=user2_tok, is_public=True - ) - join_response1 = self.helper.join(kick_room_id, user1_id, tok=user1_tok) - # Kick user1 from the room - kick_response = self.helper.change_membership( - room=kick_room_id, - src=user2_id, - targ=user1_id, - tok=user2_tok, - membership=Membership.LEAVE, - extra_data={ - "reason": "Bad manners", - }, - ) - - after_kick_token = self.event_sources.get_current_token() - - # Leave the room after we already have our tokens - # - # We have to join before we can leave (leave -> leave isn't a valid transition - # or at least it doesn't work in Synapse, 403 forbidden) - join_response2 = self.helper.join(kick_room_id, user1_id, tok=user1_tok) - leave_response = self.helper.leave(kick_room_id, user1_id, tok=user1_tok) - - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=after_kick_token, - to_token=after_kick_token, - 
) - ) - - # We should still see the room because we were kicked before the `from_token` - self.assertEqual(room_id_results.keys(), {kick_room_id}) - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[kick_room_id].event_id, - kick_response["event_id"], - "Corresponding map to disambiguate the opaque event IDs: " - + str( - { - "join_response1": join_response1["event_id"], - "kick_response": kick_response["event_id"], - "join_response2": join_response2["event_id"], - "leave_response": leave_response["event_id"], - } - ), - ) - self.assertEqual(room_id_results[kick_room_id].membership, Membership.LEAVE) - self.assertNotEqual(room_id_results[kick_room_id].sender, user1_id) - # We should *NOT* be `newly_joined` because we were kicked - self.assertEqual(room_id_results[kick_room_id].newly_joined, False) - self.assertEqual(room_id_results[kick_room_id].newly_left, False) - - def test_newly_left_during_range_and_join_leave_after_to_token(self) -> None: - """ - Newly left room should show up. But we're also testing that joining and leaving - after the `to_token` doesn't mess with the results. See condition "2)" and "1a)" - comments in the `get_room_membership_for_user_at_to_token()` method. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - before_room1_token = self.event_sources.get_current_token() - - # We create the room with user2 so the room isn't left with no members when we - # leave and can still re-join. 
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) - # Join and leave the room during the from/to range - join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok) - leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok) - - after_room1_token = self.event_sources.get_current_token() - - # Join and leave the room after we already have our tokens - join_response2 = self.helper.join(room_id1, user1_id, tok=user1_tok) - leave_response2 = self.helper.leave(room_id1, user1_id, tok=user1_tok) - - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=before_room1_token, - to_token=after_room1_token, - ) - ) - - # Room should still show up because it's newly_left during the from/to range - self.assertEqual(room_id_results.keys(), {room_id1}) - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[room_id1].event_id, - leave_response1["event_id"], - "Corresponding map to disambiguate the opaque event IDs: " - + str( - { - "join_response1": join_response1["event_id"], - "leave_response1": leave_response1["event_id"], - "join_response2": join_response2["event_id"], - "leave_response2": leave_response2["event_id"], - } - ), - ) - self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE) - # We should *NOT* be `newly_joined` because we are actually `newly_left` during - # the token range - self.assertEqual(room_id_results[room_id1].newly_joined, False) - self.assertEqual(room_id_results[room_id1].newly_left, True) - - def test_newly_left_during_range_and_join_after_to_token(self) -> None: - """ - Newly left room should show up. But we're also testing that joining after the - `to_token` doesn't mess with the results. See condition "2)" and "1b)" comments - in the `get_room_membership_for_user_at_to_token()` method. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - before_room1_token = self.event_sources.get_current_token() - - # We create the room with user2 so the room isn't left with no members when we - # leave and can still re-join. - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) - # Join and leave the room during the from/to range - join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok) - leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok) - - after_room1_token = self.event_sources.get_current_token() - - # Join the room after we already have our tokens - join_response2 = self.helper.join(room_id1, user1_id, tok=user1_tok) - - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=before_room1_token, - to_token=after_room1_token, - ) - ) - - # Room should still show up because it's newly_left during the from/to range - self.assertEqual(room_id_results.keys(), {room_id1}) - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[room_id1].event_id, - leave_response1["event_id"], - "Corresponding map to disambiguate the opaque event IDs: " - + str( - { - "join_response1": join_response1["event_id"], - "leave_response1": leave_response1["event_id"], - "join_response2": join_response2["event_id"], - } - ), - ) - self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE) - # We should *NOT* be `newly_joined` because we are actually `newly_left` during - # the token range - self.assertEqual(room_id_results[room_id1].newly_joined, False) - self.assertEqual(room_id_results[room_id1].newly_left, True) - - def test_no_from_token(self) -> None: - """ - Test that if we don't provide a `from_token`, we get all 
the rooms that we had - membership in up to the `to_token`. - - Providing `from_token` only really has the effect that it marks rooms as - `newly_left` in the response. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - # We create the room with user2 so the room isn't left with no members when we - # leave and can still re-join. - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) - room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) - - # Join room1 - join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok) - - # Join and leave the room2 before the `to_token` - self.helper.join(room_id2, user1_id, tok=user1_tok) - leave_response2 = self.helper.leave(room_id2, user1_id, tok=user1_tok) - - after_room1_token = self.event_sources.get_current_token() - - # Join the room2 after we already have our tokens - self.helper.join(room_id2, user1_id, tok=user1_tok) - - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=None, - to_token=after_room1_token, - ) - ) - - # Only rooms we were joined to before the `to_token` should show up - self.assertEqual(room_id_results.keys(), {room_id1, room_id2}) - - # Room1 - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[room_id1].event_id, - join_response1["event_id"], - ) - self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) - # We should *NOT* be `newly_joined`/`newly_left` because there is no - # `from_token` to define a "live" range to compare against - self.assertEqual(room_id_results[room_id1].newly_joined, False) - self.assertEqual(room_id_results[room_id1].newly_left, False) - - # Room2 - # It should be pointing to the latest membership event 
in the from/to range - self.assertEqual( - room_id_results[room_id2].event_id, - leave_response2["event_id"], - ) - self.assertEqual(room_id_results[room_id2].membership, Membership.LEAVE) - # We should *NOT* be `newly_joined`/`newly_left` because there is no - # `from_token` to define a "live" range to compare against - self.assertEqual(room_id_results[room_id2].newly_joined, False) - self.assertEqual(room_id_results[room_id2].newly_left, False) - - def test_from_token_ahead_of_to_token(self) -> None: - """ - Test when the provided `from_token` comes after the `to_token`. We should - basically expect the same result as having no `from_token`. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - # We create the room with user2 so the room isn't left with no members when we - # leave and can still re-join. - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) - room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) - room_id3 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) - room_id4 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) - - # Join room1 before `to_token` - join_room1_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok) - - # Join and leave the room2 before `to_token` - _join_room2_response1 = self.helper.join(room_id2, user1_id, tok=user1_tok) - leave_room2_response1 = self.helper.leave(room_id2, user1_id, tok=user1_tok) - - # Note: These are purposely swapped. 
The `from_token` should come after - # the `to_token` in this test - to_token = self.event_sources.get_current_token() - - # Join room2 after `to_token` - _join_room2_response2 = self.helper.join(room_id2, user1_id, tok=user1_tok) - - # -------- - - # Join room3 after `to_token` - _join_room3_response1 = self.helper.join(room_id3, user1_id, tok=user1_tok) - - # Join and leave the room4 after `to_token` - _join_room4_response1 = self.helper.join(room_id4, user1_id, tok=user1_tok) - _leave_room4_response1 = self.helper.leave(room_id4, user1_id, tok=user1_tok) - - # Note: These are purposely swapped. The `from_token` should come after the - # `to_token` in this test - from_token = self.event_sources.get_current_token() - - # Join the room4 after we already have our tokens - self.helper.join(room_id4, user1_id, tok=user1_tok) - - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=from_token, - to_token=to_token, - ) - ) - - # In the "current" state snapshot, we're joined to all of the rooms but in the - # from/to token range... 
- self.assertIncludes( - room_id_results.keys(), - { - # Included because we were joined before both tokens - room_id1, - # Included because we had membership before the to_token - room_id2, - # Excluded because we joined after the `to_token` - # room_id3, - # Excluded because we joined after the `to_token` - # room_id4, - }, - exact=True, - ) - - # Room1 - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[room_id1].event_id, - join_room1_response1["event_id"], - ) - self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) - # We should *NOT* be `newly_joined`/`newly_left` because we joined `room1` - # before either of the tokens - self.assertEqual(room_id_results[room_id1].newly_joined, False) - self.assertEqual(room_id_results[room_id1].newly_left, False) - - # Room2 - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[room_id2].event_id, - leave_room2_response1["event_id"], - ) - self.assertEqual(room_id_results[room_id2].membership, Membership.LEAVE) - # We should *NOT* be `newly_joined`/`newly_left` because we joined and left - # `room2` before either of the tokens - self.assertEqual(room_id_results[room_id2].newly_joined, False) - self.assertEqual(room_id_results[room_id2].newly_left, False) - - def test_leave_before_range_and_join_leave_after_to_token(self) -> None: - """ - Test old left rooms. But we're also testing that joining and leaving after the - `to_token` doesn't mess with the results. See condition "1a)" comments in the - `get_room_membership_for_user_at_to_token()` method. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - # We create the room with user2 so the room isn't left with no members when we - # leave and can still re-join. 
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) - # Join and leave the room before the from/to range - self.helper.join(room_id1, user1_id, tok=user1_tok) - leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok) - - after_room1_token = self.event_sources.get_current_token() - - # Join and leave the room after we already have our tokens - self.helper.join(room_id1, user1_id, tok=user1_tok) - self.helper.leave(room_id1, user1_id, tok=user1_tok) - - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=after_room1_token, - to_token=after_room1_token, - ) - ) - - self.assertEqual(room_id_results.keys(), {room_id1}) - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[room_id1].event_id, - leave_response["event_id"], - ) - self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE) - # We should *NOT* be `newly_joined`/`newly_left` because we joined and left - # `room1` before either of the tokens - self.assertEqual(room_id_results[room_id1].newly_joined, False) - self.assertEqual(room_id_results[room_id1].newly_left, False) - - def test_leave_before_range_and_join_after_to_token(self) -> None: - """ - Test old left room. But we're also testing that joining after the `to_token` - doesn't mess with the results. See condition "1b)" comments in the - `get_room_membership_for_user_at_to_token()` method. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - # We create the room with user2 so the room isn't left with no members when we - # leave and can still re-join. 
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) - # Join and leave the room before the from/to range - self.helper.join(room_id1, user1_id, tok=user1_tok) - leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok) - - after_room1_token = self.event_sources.get_current_token() - - # Join the room after we already have our tokens - self.helper.join(room_id1, user1_id, tok=user1_tok) - - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=after_room1_token, - to_token=after_room1_token, - ) - ) - - self.assertEqual(room_id_results.keys(), {room_id1}) - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[room_id1].event_id, - leave_response["event_id"], - ) - self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE) - # We should *NOT* be `newly_joined`/`newly_left` because we joined and left - # `room1` before either of the tokens - self.assertEqual(room_id_results[room_id1].newly_joined, False) - self.assertEqual(room_id_results[room_id1].newly_left, False) - - def test_join_leave_multiple_times_during_range_and_after_to_token( - self, - ) -> None: - """ - Join and leave multiple times shouldn't affect rooms from showing up. It just - matters that we had membership in the from/to range. But we're also testing that - joining and leaving after the `to_token` doesn't mess with the results. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - before_room1_token = self.event_sources.get_current_token() - - # We create the room with user2 so the room isn't left with no members when we - # leave and can still re-join. 
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) - # Join, leave, join back to the room during the from/to range - join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok) - leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok) - join_response2 = self.helper.join(room_id1, user1_id, tok=user1_tok) - - after_room1_token = self.event_sources.get_current_token() - - # Leave and Join the room multiple times after we already have our tokens - leave_response2 = self.helper.leave(room_id1, user1_id, tok=user1_tok) - join_response3 = self.helper.join(room_id1, user1_id, tok=user1_tok) - leave_response3 = self.helper.leave(room_id1, user1_id, tok=user1_tok) - - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=before_room1_token, - to_token=after_room1_token, - ) - ) - - # Room should show up because it was newly_left and joined during the from/to range - self.assertEqual(room_id_results.keys(), {room_id1}) - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[room_id1].event_id, - join_response2["event_id"], - "Corresponding map to disambiguate the opaque event IDs: " - + str( - { - "join_response1": join_response1["event_id"], - "leave_response1": leave_response1["event_id"], - "join_response2": join_response2["event_id"], - "leave_response2": leave_response2["event_id"], - "join_response3": join_response3["event_id"], - "leave_response3": leave_response3["event_id"], - } - ), - ) - self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) - # We should be `newly_joined` because we joined during the token range - self.assertEqual(room_id_results[room_id1].newly_joined, True) - # We should *NOT* be `newly_left` because we joined during the token range and - # was still joined at the end of the range - 
self.assertEqual(room_id_results[room_id1].newly_left, False) - - def test_join_leave_multiple_times_before_range_and_after_to_token( - self, - ) -> None: - """ - Join and leave multiple times before the from/to range shouldn't affect rooms - from showing up. It just matters that we had membership in the - from/to range. But we're also testing that joining and leaving after the - `to_token` doesn't mess with the results. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - # We create the room with user2 so the room isn't left with no members when we - # leave and can still re-join. - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) - # Join, leave, join back to the room before the from/to range - join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok) - leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok) - join_response2 = self.helper.join(room_id1, user1_id, tok=user1_tok) - - after_room1_token = self.event_sources.get_current_token() - - # Leave and Join the room multiple times after we already have our tokens - leave_response2 = self.helper.leave(room_id1, user1_id, tok=user1_tok) - join_response3 = self.helper.join(room_id1, user1_id, tok=user1_tok) - leave_response3 = self.helper.leave(room_id1, user1_id, tok=user1_tok) - - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=after_room1_token, - to_token=after_room1_token, - ) - ) - - # Room should show up because we were joined before the from/to range - self.assertEqual(room_id_results.keys(), {room_id1}) - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[room_id1].event_id, - join_response2["event_id"], - "Corresponding map to disambiguate the 
opaque event IDs: " - + str( - { - "join_response1": join_response1["event_id"], - "leave_response1": leave_response1["event_id"], - "join_response2": join_response2["event_id"], - "leave_response2": leave_response2["event_id"], - "join_response3": join_response3["event_id"], - "leave_response3": leave_response3["event_id"], - } - ), - ) - self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) - # We should *NOT* be `newly_joined` because we joined before the token range - self.assertEqual(room_id_results[room_id1].newly_joined, False) - self.assertEqual(room_id_results[room_id1].newly_left, False) - - def test_invite_before_range_and_join_leave_after_to_token( - self, - ) -> None: - """ - Make it look like we joined after the token range but we were invited before the - from/to range so the room should still show up. See condition "1a)" comments in - the `get_room_membership_for_user_at_to_token()` method. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - # We create the room with user2 so the room isn't left with no members when we - # leave and can still re-join. 
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) - - # Invited to the room before the token - invite_response = self.helper.invite( - room_id1, src=user2_id, targ=user1_id, tok=user2_tok - ) - - after_room1_token = self.event_sources.get_current_token() - - # Join and leave the room after we already have our tokens - join_respsonse = self.helper.join(room_id1, user1_id, tok=user1_tok) - leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok) - - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=after_room1_token, - to_token=after_room1_token, - ) - ) - - # Room should show up because we were invited before the from/to range - self.assertEqual(room_id_results.keys(), {room_id1}) - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[room_id1].event_id, - invite_response["event_id"], - "Corresponding map to disambiguate the opaque event IDs: " - + str( - { - "invite_response": invite_response["event_id"], - "join_respsonse": join_respsonse["event_id"], - "leave_response": leave_response["event_id"], - } - ), - ) - self.assertEqual(room_id_results[room_id1].membership, Membership.INVITE) - # We should *NOT* be `newly_joined` because we were only invited before the - # token range - self.assertEqual(room_id_results[room_id1].newly_joined, False) - self.assertEqual(room_id_results[room_id1].newly_left, False) - - def test_join_and_display_name_changes_in_token_range( - self, - ) -> None: - """ - Test that we point to the correct membership event within the from/to range even - if there are multiple `join` membership events in a row indicating - `displayname`/`avatar_url` updates. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - before_room1_token = self.event_sources.get_current_token() - - # We create the room with user2 so the room isn't left with no members when we - # leave and can still re-join. - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) - join_response = self.helper.join(room_id1, user1_id, tok=user1_tok) - # Update the displayname during the token range - displayname_change_during_token_range_response = self.helper.send_state( - room_id1, - event_type=EventTypes.Member, - state_key=user1_id, - body={ - "membership": Membership.JOIN, - "displayname": "displayname during token range", - }, - tok=user1_tok, - ) - - after_room1_token = self.event_sources.get_current_token() - - # Update the displayname after the token range - displayname_change_after_token_range_response = self.helper.send_state( - room_id1, - event_type=EventTypes.Member, - state_key=user1_id, - body={ - "membership": Membership.JOIN, - "displayname": "displayname after token range", - }, - tok=user1_tok, - ) - - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=before_room1_token, - to_token=after_room1_token, - ) - ) - - # Room should show up because we were joined during the from/to range - self.assertEqual(room_id_results.keys(), {room_id1}) - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[room_id1].event_id, - displayname_change_during_token_range_response["event_id"], - "Corresponding map to disambiguate the opaque event IDs: " - + str( - { - "join_response": join_response["event_id"], - "displayname_change_during_token_range_response": displayname_change_during_token_range_response[ - "event_id" - ], - 
"displayname_change_after_token_range_response": displayname_change_after_token_range_response[ - "event_id" - ], - } - ), - ) - self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) - # We should be `newly_joined` because we joined during the token range - self.assertEqual(room_id_results[room_id1].newly_joined, True) - self.assertEqual(room_id_results[room_id1].newly_left, False) - - def test_display_name_changes_in_token_range( - self, - ) -> None: - """ - Test that we point to the correct membership event within the from/to range even - if there is `displayname`/`avatar_url` updates. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - # We create the room with user2 so the room isn't left with no members when we - # leave and can still re-join. - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) - join_response = self.helper.join(room_id1, user1_id, tok=user1_tok) - - after_room1_token = self.event_sources.get_current_token() - - # Update the displayname during the token range - displayname_change_during_token_range_response = self.helper.send_state( - room_id1, - event_type=EventTypes.Member, - state_key=user1_id, - body={ - "membership": Membership.JOIN, - "displayname": "displayname during token range", - }, - tok=user1_tok, - ) - - after_change1_token = self.event_sources.get_current_token() - - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=after_room1_token, - to_token=after_change1_token, - ) - ) - - # Room should show up because we were joined during the from/to range - self.assertEqual(room_id_results.keys(), {room_id1}) - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[room_id1].event_id, - 
displayname_change_during_token_range_response["event_id"], - "Corresponding map to disambiguate the opaque event IDs: " - + str( - { - "join_response": join_response["event_id"], - "displayname_change_during_token_range_response": displayname_change_during_token_range_response[ - "event_id" - ], - } - ), - ) - self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) - # We should *NOT* be `newly_joined` because we joined before the token range - self.assertEqual(room_id_results[room_id1].newly_joined, False) - self.assertEqual(room_id_results[room_id1].newly_left, False) - - def test_display_name_changes_before_and_after_token_range( - self, - ) -> None: - """ - Test that we point to the correct membership event even though there are no - membership events in the from/range but there are `displayname`/`avatar_url` - changes before/after the token range. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - # We create the room with user2 so the room isn't left with no members when we - # leave and can still re-join. 
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) - join_response = self.helper.join(room_id1, user1_id, tok=user1_tok) - # Update the displayname before the token range - displayname_change_before_token_range_response = self.helper.send_state( - room_id1, - event_type=EventTypes.Member, - state_key=user1_id, - body={ - "membership": Membership.JOIN, - "displayname": "displayname during token range", - }, - tok=user1_tok, - ) - - after_room1_token = self.event_sources.get_current_token() - - # Update the displayname after the token range - displayname_change_after_token_range_response = self.helper.send_state( - room_id1, - event_type=EventTypes.Member, - state_key=user1_id, - body={ - "membership": Membership.JOIN, - "displayname": "displayname after token range", - }, - tok=user1_tok, - ) - - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=after_room1_token, - to_token=after_room1_token, - ) - ) - - # Room should show up because we were joined before the from/to range - self.assertEqual(room_id_results.keys(), {room_id1}) - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[room_id1].event_id, - displayname_change_before_token_range_response["event_id"], - "Corresponding map to disambiguate the opaque event IDs: " - + str( - { - "join_response": join_response["event_id"], - "displayname_change_before_token_range_response": displayname_change_before_token_range_response[ - "event_id" - ], - "displayname_change_after_token_range_response": displayname_change_after_token_range_response[ - "event_id" - ], - } - ), - ) - self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) - # We should *NOT* be `newly_joined` because we joined before the token range - self.assertEqual(room_id_results[room_id1].newly_joined, False) - 
self.assertEqual(room_id_results[room_id1].newly_left, False) - - def test_display_name_changes_leave_after_token_range( - self, - ) -> None: - """ - Test that we point to the correct membership event within the from/to range even - if there are multiple `join` membership events in a row indicating - `displayname`/`avatar_url` updates and we leave after the `to_token`. - - See condition "1a)" comments in the `get_room_membership_for_user_at_to_token()` method. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - before_room1_token = self.event_sources.get_current_token() - - # We create the room with user2 so the room isn't left with no members when we - # leave and can still re-join. - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) - join_response = self.helper.join(room_id1, user1_id, tok=user1_tok) - # Update the displayname during the token range - displayname_change_during_token_range_response = self.helper.send_state( - room_id1, - event_type=EventTypes.Member, - state_key=user1_id, - body={ - "membership": Membership.JOIN, - "displayname": "displayname during token range", - }, - tok=user1_tok, - ) - - after_room1_token = self.event_sources.get_current_token() - - # Update the displayname after the token range - displayname_change_after_token_range_response = self.helper.send_state( - room_id1, - event_type=EventTypes.Member, - state_key=user1_id, - body={ - "membership": Membership.JOIN, - "displayname": "displayname after token range", - }, - tok=user1_tok, - ) - - # Leave after the token - self.helper.leave(room_id1, user1_id, tok=user1_tok) - - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=before_room1_token, - to_token=after_room1_token, - ) - ) - - # Room should show up because 
we were joined during the from/to range - self.assertEqual(room_id_results.keys(), {room_id1}) - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[room_id1].event_id, - displayname_change_during_token_range_response["event_id"], - "Corresponding map to disambiguate the opaque event IDs: " - + str( - { - "join_response": join_response["event_id"], - "displayname_change_during_token_range_response": displayname_change_during_token_range_response[ - "event_id" - ], - "displayname_change_after_token_range_response": displayname_change_after_token_range_response[ - "event_id" - ], - } - ), - ) - self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) - # We should be `newly_joined` because we joined during the token range - self.assertEqual(room_id_results[room_id1].newly_joined, True) - self.assertEqual(room_id_results[room_id1].newly_left, False) - - def test_display_name_changes_join_after_token_range( - self, - ) -> None: - """ - Test that multiple `join` membership events (after the `to_token`) in a row - indicating `displayname`/`avatar_url` updates doesn't affect the results (we - joined after the token range so it shouldn't show up) - - See condition "1b)" comments in the `get_room_membership_for_user_at_to_token()` method. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - before_room1_token = self.event_sources.get_current_token() - - # We create the room with user2 so the room isn't left with no members when we - # leave and can still re-join. 
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) - - after_room1_token = self.event_sources.get_current_token() - - self.helper.join(room_id1, user1_id, tok=user1_tok) - # Update the displayname after the token range - self.helper.send_state( - room_id1, - event_type=EventTypes.Member, - state_key=user1_id, - body={ - "membership": Membership.JOIN, - "displayname": "displayname after token range", - }, - tok=user1_tok, - ) - - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=before_room1_token, - to_token=after_room1_token, - ) - ) - - # Room shouldn't show up because we joined after the from/to range - self.assertEqual(room_id_results.keys(), set()) - - def test_newly_joined_with_leave_join_in_token_range( - self, - ) -> None: - """ - Test that even though we're joined before the token range, if we leave and join - within the token range, it's still counted as `newly_joined`. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - # We create the room with user2 so the room isn't left with no members when we - # leave and can still re-join. 
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) - self.helper.join(room_id1, user1_id, tok=user1_tok) - - after_room1_token = self.event_sources.get_current_token() - - # Leave and join back during the token range - self.helper.leave(room_id1, user1_id, tok=user1_tok) - join_response2 = self.helper.join(room_id1, user1_id, tok=user1_tok) - - after_more_changes_token = self.event_sources.get_current_token() - - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=after_room1_token, - to_token=after_more_changes_token, - ) - ) - - # Room should show up because we were joined during the from/to range - self.assertEqual(room_id_results.keys(), {room_id1}) - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[room_id1].event_id, - join_response2["event_id"], - ) - self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) - # We should be considered `newly_joined` because there is some non-join event in - # between our latest join event. - self.assertEqual(room_id_results[room_id1].newly_joined, True) - self.assertEqual(room_id_results[room_id1].newly_left, False) - - def test_newly_joined_only_joins_during_token_range( - self, - ) -> None: - """ - Test that a join and more joins caused by display name changes, all during the - token range, still count as `newly_joined`. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - before_room1_token = self.event_sources.get_current_token() - - # We create the room with user2 so the room isn't left with no members when we - # leave and can still re-join. 
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) - # Join the room (during the token range) - join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok) - # Update the displayname during the token range (looks like another join) - displayname_change_during_token_range_response1 = self.helper.send_state( - room_id1, - event_type=EventTypes.Member, - state_key=user1_id, - body={ - "membership": Membership.JOIN, - "displayname": "displayname during token range", - }, - tok=user1_tok, - ) - # Update the displayname during the token range (looks like another join) - displayname_change_during_token_range_response2 = self.helper.send_state( - room_id1, - event_type=EventTypes.Member, - state_key=user1_id, - body={ - "membership": Membership.JOIN, - "displayname": "displayname during token range", - }, - tok=user1_tok, - ) - - after_room1_token = self.event_sources.get_current_token() - - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=before_room1_token, - to_token=after_room1_token, - ) - ) - - # Room should show up because we are `newly_joined` during the from/to range - self.assertEqual(room_id_results.keys(), {room_id1}) - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[room_id1].event_id, - displayname_change_during_token_range_response2["event_id"], - "Corresponding map to disambiguate the opaque event IDs: " - + str( - { - "join_response1": join_response1["event_id"], - "displayname_change_during_token_range_response1": displayname_change_during_token_range_response1[ - "event_id" - ], - "displayname_change_during_token_range_response2": displayname_change_during_token_range_response2[ - "event_id" - ], - } - ), - ) - self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) - # We should be `newly_joined` because we 
first joined during the token range - self.assertEqual(room_id_results[room_id1].newly_joined, True) - self.assertEqual(room_id_results[room_id1].newly_left, False) - - def test_multiple_rooms_are_not_confused( - self, - ) -> None: - """ - Test that multiple rooms are not confused as we fixup the list. This test is - spawning from a real world bug in the code where I was accidentally using - `event.room_id` in one of the fix-up loops but the `event` being referenced was - actually from a different loop. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - # We create the room with user2 so the room isn't left with no members when we - # leave and can still re-join. - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) - room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) - - # Invited and left the room before the token - self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok) - leave_room1_response = self.helper.leave(room_id1, user1_id, tok=user1_tok) - # Invited to room2 - invite_room2_response = self.helper.invite( - room_id2, src=user2_id, targ=user1_id, tok=user2_tok - ) - - before_room3_token = self.event_sources.get_current_token() - - # Invited and left room3 during the from/to range - room_id3 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) - self.helper.invite(room_id3, src=user2_id, targ=user1_id, tok=user2_tok) - leave_room3_response = self.helper.leave(room_id3, user1_id, tok=user1_tok) - - after_room3_token = self.event_sources.get_current_token() - - # Join and leave the room after we already have our tokens - self.helper.join(room_id1, user1_id, tok=user1_tok) - self.helper.leave(room_id1, user1_id, tok=user1_tok) - # Leave room2 - self.helper.leave(room_id2, user1_id, tok=user1_tok) - # Leave room3 - 
self.helper.leave(room_id3, user1_id, tok=user1_tok) - - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=before_room3_token, - to_token=after_room3_token, - ) - ) - - self.assertEqual( - room_id_results.keys(), - { - # Left before the from/to range - room_id1, - # Invited before the from/to range - room_id2, - # `newly_left` during the from/to range - room_id3, - }, - ) - - # Room1 - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[room_id1].event_id, - leave_room1_response["event_id"], - ) - self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE) - # We should *NOT* be `newly_joined`/`newly_left` because we were invited and left - # before the token range - self.assertEqual(room_id_results[room_id1].newly_joined, False) - self.assertEqual(room_id_results[room_id1].newly_left, False) - - # Room2 - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[room_id2].event_id, - invite_room2_response["event_id"], - ) - self.assertEqual(room_id_results[room_id2].membership, Membership.INVITE) - # We should *NOT* be `newly_joined`/`newly_left` because we were invited before - # the token range - self.assertEqual(room_id_results[room_id2].newly_joined, False) - self.assertEqual(room_id_results[room_id2].newly_left, False) - - # Room3 - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[room_id3].event_id, - leave_room3_response["event_id"], - ) - self.assertEqual(room_id_results[room_id3].membership, Membership.LEAVE) - # We should be `newly_left` because we were invited and left during - # the token range - self.assertEqual(room_id_results[room_id3].newly_joined, False) - self.assertEqual(room_id_results[room_id3].newly_left, True) - - def test_state_reset(self) -> None: 
- """ - Test a state reset scenario where the user gets removed from the room (when - there is no corresponding leave event) - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - # The room where the state reset will happen - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok) - - # Join another room so we don't hit the short-circuit and return early if they - # have no room membership - room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id2, user1_id, tok=user1_tok) - - before_reset_token = self.event_sources.get_current_token() - - # Send another state event to make a position for the state reset to happen at - dummy_state_response = self.helper.send_state( - room_id1, - event_type="foobarbaz", - state_key="", - body={"foo": "bar"}, - tok=user2_tok, - ) - dummy_state_pos = self.get_success( - self.store.get_position_for_event(dummy_state_response["event_id"]) - ) - - # Mock a state reset removing the membership for user1 in the current state - self.get_success( - self.store.db_pool.simple_delete( - table="current_state_events", - keyvalues={ - "room_id": room_id1, - "type": EventTypes.Member, - "state_key": user1_id, - }, - desc="state reset user in current_state_events", - ) - ) - self.get_success( - self.store.db_pool.simple_delete( - table="local_current_membership", - keyvalues={ - "room_id": room_id1, - "user_id": user1_id, - }, - desc="state reset user in local_current_membership", - ) - ) - self.get_success( - self.store.db_pool.simple_insert( - table="current_state_delta_stream", - values={ - "stream_id": dummy_state_pos.stream, - "room_id": room_id1, - "type": EventTypes.Member, - "state_key": user1_id, - "event_id": None, - "prev_event_id": join_response1["event_id"], - "instance_name": 
dummy_state_pos.instance_name, - }, - desc="state reset user in current_state_delta_stream", - ) - ) - - # Manually bust the cache since we're just manually messing with the database - # and not causing an actual state reset. - self.store._membership_stream_cache.entity_has_changed( - user1_id, dummy_state_pos.stream - ) - - after_reset_token = self.event_sources.get_current_token() - - # The function under test - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=before_reset_token, - to_token=after_reset_token, - ) - ) - - # Room1 should show up because it was `newly_left` via state reset during the from/to range - self.assertEqual(room_id_results.keys(), {room_id1, room_id2}) - # It should be pointing to no event because we were removed from the room - # without a corresponding leave event - self.assertEqual( - room_id_results[room_id1].event_id, - None, - ) - # State reset caused us to leave the room and there is no corresponding leave event - self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE) - # We should *NOT* be `newly_joined` because we joined before the token range - self.assertEqual(room_id_results[room_id1].newly_joined, False) - # We should be `newly_left` because we were removed via state reset during the from/to range - self.assertEqual(room_id_results[room_id1].newly_left, True) - - -class GetRoomMembershipForUserAtToTokenShardTestCase(BaseMultiWorkerStreamTestCase): - """ - Tests Sliding Sync handler `get_room_membership_for_user_at_to_token()` to make sure it works with - sharded event stream_writers enabled - """ - - servlets = [ - admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - ] - - def default_config(self) -> dict: - config = super().default_config() - # Enable sliding sync - config["experimental_features"] = {"msc3575_enabled": True} - - # Enable sharded event 
stream_writers - config["stream_writers"] = {"events": ["worker1", "worker2", "worker3"]} - config["instance_map"] = { - "main": {"host": "testserv", "port": 8765}, - "worker1": {"host": "testserv", "port": 1001}, - "worker2": {"host": "testserv", "port": 1002}, - "worker3": {"host": "testserv", "port": 1003}, - } - return config - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.sliding_sync_handler = self.hs.get_sliding_sync_handler() - self.store = self.hs.get_datastores().main - self.event_sources = hs.get_event_sources() - - def _create_room(self, room_id: str, user_id: str, tok: str) -> None: - """ - Create a room with a specific room_id. We use this so that we have a - consistent room_id across test runs that hashes to the same value and will be - sharded to a known worker in the tests. - """ - - # We control the room ID generation by patching out the - # `_generate_room_id` method - with patch( - "synapse.handlers.room.RoomCreationHandler._generate_room_id" - ) as mock: - mock.side_effect = lambda: room_id - self.helper.create_room_as(user_id, tok=tok) - - def test_sharded_event_persisters(self) -> None: - """ - This test should catch bugs that would come from flawed stream position - (`stream_ordering`) comparisons or making `RoomStreamToken`'s naively. To - compare event positions properly, you need to consider both the `instance_name` - and `stream_ordering` together. - - The test creates three event persister workers and a room that is sharded to - each worker. On worker2, we make the event stream position stuck so that it lags - behind the other workers and we start getting `RoomStreamToken` that have an - `instance_map` component (i.e. `m{min_pos}~{writer1}.{pos1}~{writer2}.{pos2}`). - - We then send some events to advance the stream positions of worker1 and worker3 - but worker2 is lagging behind because it's stuck. 
We are specifically testing - that `get_room_membership_for_user_at_to_token(from_token=xxx, to_token=xxx)` should work - correctly in these adverse conditions. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - self.make_worker_hs( - "synapse.app.generic_worker", - {"worker_name": "worker1"}, - ) - - worker_hs2 = self.make_worker_hs( - "synapse.app.generic_worker", - {"worker_name": "worker2"}, - ) - - self.make_worker_hs( - "synapse.app.generic_worker", - {"worker_name": "worker3"}, - ) - - # Specially crafted room IDs that get persisted on different workers. - # - # Sharded to worker1 - room_id1 = "!fooo:test" - # Sharded to worker2 - room_id2 = "!bar:test" - # Sharded to worker3 - room_id3 = "!quux:test" - - # Create rooms on the different workers. - self._create_room(room_id1, user2_id, user2_tok) - self._create_room(room_id2, user2_id, user2_tok) - self._create_room(room_id3, user2_id, user2_tok) - join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok) - join_response2 = self.helper.join(room_id2, user1_id, tok=user1_tok) - # Leave room2 - leave_room2_response = self.helper.leave(room_id2, user1_id, tok=user1_tok) - join_response3 = self.helper.join(room_id3, user1_id, tok=user1_tok) - # Leave room3 - self.helper.leave(room_id3, user1_id, tok=user1_tok) - - # Ensure that the events were sharded to different workers. 
- pos1 = self.get_success( - self.store.get_position_for_event(join_response1["event_id"]) - ) - self.assertEqual(pos1.instance_name, "worker1") - pos2 = self.get_success( - self.store.get_position_for_event(join_response2["event_id"]) - ) - self.assertEqual(pos2.instance_name, "worker2") - pos3 = self.get_success( - self.store.get_position_for_event(join_response3["event_id"]) - ) - self.assertEqual(pos3.instance_name, "worker3") - - before_stuck_activity_token = self.event_sources.get_current_token() - - # We now gut wrench into the events stream `MultiWriterIdGenerator` on worker2 to - # mimic it getting stuck persisting an event. This ensures that when we send an - # event on worker1/worker3 we end up in a state where worker2 events stream - # position lags that on worker1/worker3, resulting in a RoomStreamToken with a - # non-empty `instance_map` component. - # - # Worker2's event stream position will not advance until we call `__aexit__` - # again. - worker_store2 = worker_hs2.get_datastores().main - assert isinstance(worker_store2._stream_id_gen, MultiWriterIdGenerator) - actx = worker_store2._stream_id_gen.get_next() - self.get_success(actx.__aenter__()) - - # For room_id1/worker1: leave and join the room to advance the stream position - # and generate membership changes. - self.helper.leave(room_id1, user1_id, tok=user1_tok) - join_room1_response = self.helper.join(room_id1, user1_id, tok=user1_tok) - # For room_id2/worker2: which is currently stuck, join the room. - join_on_worker2_response = self.helper.join(room_id2, user1_id, tok=user1_tok) - # For room_id3/worker3: leave and join the room to advance the stream position - # and generate membership changes. 
- self.helper.leave(room_id3, user1_id, tok=user1_tok) - join_on_worker3_response = self.helper.join(room_id3, user1_id, tok=user1_tok) - - # Get a token while things are stuck after our activity - stuck_activity_token = self.event_sources.get_current_token() - # Let's make sure we're working with a token that has an `instance_map` - self.assertNotEqual(len(stuck_activity_token.room_key.instance_map), 0) - - # Just double check that the join event on worker2 (that is stuck) happened - # after the position recorded for worker2 in the token but before the max - # position in the token. This is crucial for the behavior we're trying to test. - join_on_worker2_pos = self.get_success( - self.store.get_position_for_event(join_on_worker2_response["event_id"]) - ) - # Ensure the join technically came after our token - self.assertGreater( - join_on_worker2_pos.stream, - stuck_activity_token.room_key.get_stream_pos_for_instance("worker2"), - ) - # But less than the max stream position of some other worker - self.assertLess( - join_on_worker2_pos.stream, - # max - stuck_activity_token.room_key.get_max_stream_pos(), - ) - - # Just double check that the join event on worker3 happened after the min stream - # value in the token but still within the position recorded for worker3. This is - # crucial for the behavior we're trying to test. - join_on_worker3_pos = self.get_success( - self.store.get_position_for_event(join_on_worker3_response["event_id"]) - ) - # Ensure the join came after the min but still encapsulated by the token - self.assertGreaterEqual( - join_on_worker3_pos.stream, - # min - stuck_activity_token.room_key.stream, - ) - self.assertLessEqual( - join_on_worker3_pos.stream, - stuck_activity_token.room_key.get_stream_pos_for_instance("worker3"), - ) - - # We finish fake-persisting the event we started above and advance worker2's - # event stream position (unstuck worker2). 
- self.get_success(actx.__aexit__(None, None, None)) - - # The function under test - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), - from_token=before_stuck_activity_token, - to_token=stuck_activity_token, - ) - ) - - self.assertEqual( - room_id_results.keys(), - { - room_id1, - room_id2, - room_id3, - }, - ) - - # Room1 - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[room_id1].event_id, - join_room1_response["event_id"], - ) - self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) - # We should be `newly_joined` because we joined during the token range - self.assertEqual(room_id_results[room_id1].newly_joined, True) - self.assertEqual(room_id_results[room_id1].newly_left, False) - - # Room2 - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[room_id2].event_id, - leave_room2_response["event_id"], - ) - self.assertEqual(room_id_results[room_id2].membership, Membership.LEAVE) - # room_id2 should *NOT* be considered `newly_left` because we left before the - # from/to range and the join event during the range happened while worker2 was - # stuck. This means that from the perspective of the master, where the - # `stuck_activity_token` is generated, the stream position for worker2 wasn't - # advanced to the join yet. Looking at the `instance_map`, the join technically - # comes after `stuck_activity_token`. 
- self.assertEqual(room_id_results[room_id2].newly_joined, False) - self.assertEqual(room_id_results[room_id2].newly_left, False) - - # Room3 - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[room_id3].event_id, - join_on_worker3_response["event_id"], - ) - self.assertEqual(room_id_results[room_id3].membership, Membership.JOIN) - # We should be `newly_joined` because we joined during the token range - self.assertEqual(room_id_results[room_id3].newly_joined, True) - self.assertEqual(room_id_results[room_id3].newly_left, False) - - -class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase): - """ - Tests Sliding Sync handler `filter_rooms_relevant_for_sync()` to make sure it returns - the correct list of rooms IDs. - """ - - servlets = [ - admin.register_servlets, - knock.register_servlets, - login.register_servlets, - room.register_servlets, - ] - - def default_config(self) -> JsonDict: - config = super().default_config() - # Enable sliding sync - config["experimental_features"] = {"msc3575_enabled": True} - return config - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.sliding_sync_handler = self.hs.get_sliding_sync_handler() - self.store = self.hs.get_datastores().main - self.event_sources = hs.get_event_sources() - self.storage_controllers = hs.get_storage_controllers() - - def _get_sync_room_ids_for_user( - self, - user: UserID, - to_token: StreamToken, - from_token: Optional[StreamToken], - ) -> Dict[str, _RoomMembershipForUser]: - """ - Get the rooms the user should be syncing with - """ - room_membership_for_user_map = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - user=user, - from_token=from_token, - to_token=to_token, - ) - ) - filtered_sync_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms_relevant_for_sync( - user=user, - room_membership_for_user_map=room_membership_for_user_map, - 
) - ) - - return filtered_sync_room_map - - def test_no_rooms(self) -> None: - """ - Test when the user has never joined any rooms before - """ - user1_id = self.register_user("user1", "pass") - # user1_tok = self.login(user1_id, "pass") - - now_token = self.event_sources.get_current_token() - - room_id_results = self._get_sync_room_ids_for_user( - UserID.from_string(user1_id), - from_token=now_token, - to_token=now_token, - ) - - self.assertEqual(room_id_results.keys(), set()) - - def test_basic_rooms(self) -> None: - """ - Test that rooms that the user is joined to, invited to, banned from, and knocked - on show up. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - before_room_token = self.event_sources.get_current_token() - - join_room_id = self.helper.create_room_as(user2_id, tok=user2_tok) - join_response = self.helper.join(join_room_id, user1_id, tok=user1_tok) - - # Setup the invited room (user2 invites user1 to the room) - invited_room_id = self.helper.create_room_as(user2_id, tok=user2_tok) - invite_response = self.helper.invite( - invited_room_id, targ=user1_id, tok=user2_tok - ) - - # Setup the ban room (user2 bans user1 from the room) - ban_room_id = self.helper.create_room_as( - user2_id, tok=user2_tok, is_public=True - ) - self.helper.join(ban_room_id, user1_id, tok=user1_tok) - ban_response = self.helper.ban( - ban_room_id, src=user2_id, targ=user1_id, tok=user2_tok - ) - - # Setup the knock room (user1 knocks on the room) - knock_room_id = self.helper.create_room_as( - user2_id, tok=user2_tok, room_version=RoomVersions.V7.identifier - ) - self.helper.send_state( - knock_room_id, - EventTypes.JoinRules, - {"join_rule": JoinRules.KNOCK}, - tok=user2_tok, - ) - # User1 knocks on the room - knock_channel = self.make_request( - "POST", - "/_matrix/client/r0/knock/%s" % (knock_room_id,), - b"{}", - user1_tok, - ) 
- self.assertEqual(knock_channel.code, 200, knock_channel.result) - knock_room_membership_state_event = self.get_success( - self.storage_controllers.state.get_current_state_event( - knock_room_id, EventTypes.Member, user1_id - ) - ) - assert knock_room_membership_state_event is not None - - after_room_token = self.event_sources.get_current_token() - - room_id_results = self._get_sync_room_ids_for_user( - UserID.from_string(user1_id), - from_token=before_room_token, - to_token=after_room_token, - ) - - # Ensure that the joined, invited, ban, and knock rooms show up - self.assertEqual( - room_id_results.keys(), - { - join_room_id, - invited_room_id, - ban_room_id, - knock_room_id, - }, - ) - # It should be pointing to the respective membership event (latest - # membership event in the from/to range) - self.assertEqual( - room_id_results[join_room_id].event_id, - join_response["event_id"], - ) - self.assertEqual(room_id_results[join_room_id].membership, Membership.JOIN) - self.assertEqual(room_id_results[join_room_id].newly_joined, True) - self.assertEqual(room_id_results[join_room_id].newly_left, False) - - self.assertEqual( - room_id_results[invited_room_id].event_id, - invite_response["event_id"], - ) - self.assertEqual(room_id_results[invited_room_id].membership, Membership.INVITE) - self.assertEqual(room_id_results[invited_room_id].newly_joined, False) - self.assertEqual(room_id_results[invited_room_id].newly_left, False) - - self.assertEqual( - room_id_results[ban_room_id].event_id, - ban_response["event_id"], - ) - self.assertEqual(room_id_results[ban_room_id].membership, Membership.BAN) - self.assertEqual(room_id_results[ban_room_id].newly_joined, False) - self.assertEqual(room_id_results[ban_room_id].newly_left, False) - - self.assertEqual( - room_id_results[knock_room_id].event_id, - knock_room_membership_state_event.event_id, - ) - self.assertEqual(room_id_results[knock_room_id].membership, Membership.KNOCK) - 
        self.assertEqual(room_id_results[knock_room_id].newly_joined, False)
        self.assertEqual(room_id_results[knock_room_id].newly_left, False)

    def test_only_newly_left_rooms_show_up(self) -> None:
        """
        Test that `newly_left` rooms still show up in the sync response but rooms that
        were left before the `from_token` don't show up. See condition "2)" comments in
        the `get_room_membership_for_user_at_to_token()` method.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")

        # Leave before we calculate the `from_token`
        room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
        self.helper.leave(room_id1, user1_id, tok=user1_tok)

        after_room1_token = self.event_sources.get_current_token()

        # Leave during the from_token/to_token range (newly_left)
        room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok)
        _leave_response2 = self.helper.leave(room_id2, user1_id, tok=user1_tok)

        after_room2_token = self.event_sources.get_current_token()

        room_id_results = self._get_sync_room_ids_for_user(
            UserID.from_string(user1_id),
            from_token=after_room1_token,
            to_token=after_room2_token,
        )

        # Only the `newly_left` room should show up
        self.assertEqual(room_id_results.keys(), {room_id2})
        self.assertEqual(
            room_id_results[room_id2].event_id,
            _leave_response2["event_id"],
        )
        # We should *NOT* be `newly_joined` because we are instead `newly_left`
        self.assertEqual(room_id_results[room_id2].newly_joined, False)
        self.assertEqual(room_id_results[room_id2].newly_left, True)

    def test_get_kicked_room(self) -> None:
        """
        Test that a room that the user was kicked from still shows up. When the user
        comes back to their client, they should see that they were kicked.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        user2_id = self.register_user("user2", "pass")
        user2_tok = self.login(user2_id, "pass")

        # Setup the kick room (user2 kicks user1 from the room)
        kick_room_id = self.helper.create_room_as(
            user2_id, tok=user2_tok, is_public=True
        )
        self.helper.join(kick_room_id, user1_id, tok=user1_tok)
        # Kick user1 from the room (a kick is a LEAVE sent by someone else)
        kick_response = self.helper.change_membership(
            room=kick_room_id,
            src=user2_id,
            targ=user1_id,
            tok=user2_tok,
            membership=Membership.LEAVE,
            extra_data={
                "reason": "Bad manners",
            },
        )

        after_kick_token = self.event_sources.get_current_token()

        room_id_results = self._get_sync_room_ids_for_user(
            UserID.from_string(user1_id),
            from_token=after_kick_token,
            to_token=after_kick_token,
        )

        # The kicked room should show up
        self.assertEqual(room_id_results.keys(), {kick_room_id})
        # It should be pointing to the latest membership event in the from/to range
        self.assertEqual(
            room_id_results[kick_room_id].event_id,
            kick_response["event_id"],
        )
        self.assertEqual(room_id_results[kick_room_id].membership, Membership.LEAVE)
        # A kick means the leave was sent by someone other than us
        self.assertNotEqual(room_id_results[kick_room_id].sender, user1_id)
        # We should *NOT* be `newly_joined` because we were not joined at the time
        # of the `to_token`.
        self.assertEqual(room_id_results[kick_room_id].newly_joined, False)
        self.assertEqual(room_id_results[kick_room_id].newly_left, False)

    def test_state_reset(self) -> None:
        """
        Test a state reset scenario where the user gets removed from the room (when
        there is no corresponding leave event)
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        user2_id = self.register_user("user2", "pass")
        user2_tok = self.login(user2_id, "pass")

        # The room where the state reset will happen
        room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
        join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)

        # Join another room so we don't hit the short-circuit and return early if they
        # have no room membership
        room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok)
        self.helper.join(room_id2, user1_id, tok=user1_tok)

        before_reset_token = self.event_sources.get_current_token()

        # Send another state event to make a position for the state reset to happen at
        dummy_state_response = self.helper.send_state(
            room_id1,
            event_type="foobarbaz",
            state_key="",
            body={"foo": "bar"},
            tok=user2_tok,
        )
        dummy_state_pos = self.get_success(
            self.store.get_position_for_event(dummy_state_response["event_id"])
        )

        # Mock a state reset removing the membership for user1 in the current state
        self.get_success(
            self.store.db_pool.simple_delete(
                table="current_state_events",
                keyvalues={
                    "room_id": room_id1,
                    "type": EventTypes.Member,
                    "state_key": user1_id,
                },
                desc="state reset user in current_state_events",
            )
        )
        self.get_success(
            self.store.db_pool.simple_delete(
                table="local_current_membership",
                keyvalues={
                    "room_id": room_id1,
                    "user_id": user1_id,
                },
                desc="state reset user in local_current_membership",
            )
        )
        self.get_success(
            self.store.db_pool.simple_insert(
                table="current_state_delta_stream",
                values={
                    "stream_id":
                    dummy_state_pos.stream,
                    "room_id": room_id1,
                    "type": EventTypes.Member,
                    "state_key": user1_id,
                    # `event_id=None` is what signals the state reset (membership
                    # removed without a replacement event)
                    "event_id": None,
                    "prev_event_id": join_response1["event_id"],
                    "instance_name": dummy_state_pos.instance_name,
                },
                desc="state reset user in current_state_delta_stream",
            )
        )

        # Manually bust the cache since we're just manually messing with the database
        # and not causing an actual state reset.
        self.store._membership_stream_cache.entity_has_changed(
            user1_id, dummy_state_pos.stream
        )

        after_reset_token = self.event_sources.get_current_token()

        # The function under test
        room_id_results = self._get_sync_room_ids_for_user(
            UserID.from_string(user1_id),
            from_token=before_reset_token,
            to_token=after_reset_token,
        )

        # Room1 should show up because it was `newly_left` via state reset during the from/to range
        self.assertEqual(room_id_results.keys(), {room_id1, room_id2})
        # It should be pointing to no event because we were removed from the room
        # without a corresponding leave event
        self.assertEqual(
            room_id_results[room_id1].event_id,
            None,
        )
        # State reset caused us to leave the room and there is no corresponding leave event
        self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
        # We should *NOT* be `newly_joined` because we joined before the token range
        self.assertEqual(room_id_results[room_id1].newly_joined, False)
        # We should be `newly_left` because we were removed via state reset during the from/to range
        self.assertEqual(room_id_results[room_id1].newly_left, True)


class FilterRoomsTestCase(HomeserverTestCase):
    """
    Tests Sliding Sync handler `filter_rooms()` to make sure it includes/excludes rooms
    correctly.
    """

    servlets = [
        admin.register_servlets,
        knock.register_servlets,
        login.register_servlets,
        room.register_servlets,
    ]

    def default_config(self) -> JsonDict:
        config = super().default_config()
        # Enable sliding sync
        config["experimental_features"] = {"msc3575_enabled": True}
        return config

    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
        self.sliding_sync_handler = self.hs.get_sliding_sync_handler()
        self.store = self.hs.get_datastores().main
        self.event_sources = hs.get_event_sources()

    def _get_sync_room_ids_for_user(
        self,
        user: UserID,
        to_token: StreamToken,
        from_token: Optional[StreamToken],
    ) -> Dict[str, _RoomMembershipForUser]:
        """
        Get the rooms the user should be syncing with
        """
        room_membership_for_user_map = self.get_success(
            self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
                user=user,
                from_token=from_token,
                to_token=to_token,
            )
        )
        filtered_sync_room_map = self.get_success(
            self.sliding_sync_handler.filter_rooms_relevant_for_sync(
                user=user,
                room_membership_for_user_map=room_membership_for_user_map,
            )
        )

        return filtered_sync_room_map

    def _create_dm_room(
        self,
        inviter_user_id: str,
        inviter_tok: str,
        invitee_user_id: str,
        invitee_tok: str,
    ) -> str:
        """
        Helper to create a DM room as the "inviter" and invite the "invitee" user to the room. The
        "invitee" user also will join the room. The `m.direct` account data will be set
        for both users.
- """ - - # Create a room and send an invite the other user - room_id = self.helper.create_room_as( - inviter_user_id, - is_public=False, - tok=inviter_tok, - ) - self.helper.invite( - room_id, - src=inviter_user_id, - targ=invitee_user_id, - tok=inviter_tok, - extra_data={"is_direct": True}, - ) - # Person that was invited joins the room - self.helper.join(room_id, invitee_user_id, tok=invitee_tok) - - # Mimic the client setting the room as a direct message in the global account - # data - self.get_success( - self.store.add_account_data_for_user( - invitee_user_id, - AccountDataTypes.DIRECT, - {inviter_user_id: [room_id]}, - ) - ) - self.get_success( - self.store.add_account_data_for_user( - inviter_user_id, - AccountDataTypes.DIRECT, - {invitee_user_id: [room_id]}, - ) - ) - - return room_id - - _remote_invite_count: int = 0 - - def _create_remote_invite_room_for_user( - self, - invitee_user_id: str, - unsigned_invite_room_state: Optional[List[StrippedStateEvent]], - ) -> str: - """ - Create a fake invite for a remote room and persist it. - - We don't have any state for these kind of rooms and can only rely on the - stripped state included in the unsigned portion of the invite event to identify - the room. - - Args: - invitee_user_id: The person being invited - unsigned_invite_room_state: List of stripped state events to assist the - receiver in identifying the room. 
- - Returns: - The room ID of the remote invite room - """ - invite_room_id = f"!test_room{self._remote_invite_count}:remote_server" - - invite_event_dict = { - "room_id": invite_room_id, - "sender": "@inviter:remote_server", - "state_key": invitee_user_id, - "depth": 1, - "origin_server_ts": 1, - "type": EventTypes.Member, - "content": {"membership": Membership.INVITE}, - "auth_events": [], - "prev_events": [], - } - if unsigned_invite_room_state is not None: - serialized_stripped_state_events = [] - for stripped_event in unsigned_invite_room_state: - serialized_stripped_state_events.append( - { - "type": stripped_event.type, - "state_key": stripped_event.state_key, - "sender": stripped_event.sender, - "content": stripped_event.content, - } - ) - - invite_event_dict["unsigned"] = { - "invite_room_state": serialized_stripped_state_events - } - - invite_event = make_event_from_dict( - invite_event_dict, - room_version=RoomVersions.V10, - ) - invite_event.internal_metadata.outlier = True - invite_event.internal_metadata.out_of_band_membership = True - - self.get_success( - self.store.maybe_store_room_on_outlier_membership( - room_id=invite_room_id, room_version=invite_event.room_version - ) - ) - context = EventContext.for_outlier(self.hs.get_storage_controllers()) - persist_controller = self.hs.get_storage_controllers().persistence - assert persist_controller is not None - self.get_success(persist_controller.persist_event(invite_event, context)) - - self._remote_invite_count += 1 - - return invite_room_id - - def test_filter_dm_rooms(self) -> None: - """ - Test `filter.is_dm` for DM rooms - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - # Create a normal room - room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - - # Create a DM room - dm_room_id = self._create_dm_room( - inviter_user_id=user1_id, - 
            inviter_tok=user1_tok,
            invitee_user_id=user2_id,
            invitee_tok=user2_tok,
        )

        after_rooms_token = self.event_sources.get_current_token()

        # Get the rooms the user should be syncing with
        sync_room_map = self._get_sync_room_ids_for_user(
            UserID.from_string(user1_id),
            from_token=None,
            to_token=after_rooms_token,
        )

        # Try with `is_dm=True`
        truthy_filtered_room_map = self.get_success(
            self.sliding_sync_handler.filter_rooms(
                UserID.from_string(user1_id),
                sync_room_map,
                SlidingSyncConfig.SlidingSyncList.Filters(
                    is_dm=True,
                ),
                after_rooms_token,
            )
        )

        self.assertEqual(truthy_filtered_room_map.keys(), {dm_room_id})

        # Try with `is_dm=False`
        falsy_filtered_room_map = self.get_success(
            self.sliding_sync_handler.filter_rooms(
                UserID.from_string(user1_id),
                sync_room_map,
                SlidingSyncConfig.SlidingSyncList.Filters(
                    is_dm=False,
                ),
                after_rooms_token,
            )
        )

        self.assertEqual(falsy_filtered_room_map.keys(), {room_id})

    def test_filter_encrypted_rooms(self) -> None:
        """
        Test `filter.is_encrypted` for encrypted rooms
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")

        # Create an unencrypted room
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create an encrypted room
        encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
        self.helper.send_state(
            encrypted_room_id,
            EventTypes.RoomEncryption,
            {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
            tok=user1_tok,
        )

        after_rooms_token = self.event_sources.get_current_token()

        # Get the rooms the user should be syncing with
        sync_room_map = self._get_sync_room_ids_for_user(
            UserID.from_string(user1_id),
            from_token=None,
            to_token=after_rooms_token,
        )

        # Try with `is_encrypted=True`
        truthy_filtered_room_map = self.get_success(
            self.sliding_sync_handler.filter_rooms(
                UserID.from_string(user1_id),
                sync_room_map,
                SlidingSyncConfig.SlidingSyncList.Filters(
                    is_encrypted=True,
                ),
                after_rooms_token,
            )
        )

        self.assertEqual(truthy_filtered_room_map.keys(), {encrypted_room_id})

        # Try with `is_encrypted=False`
        falsy_filtered_room_map = self.get_success(
            self.sliding_sync_handler.filter_rooms(
                UserID.from_string(user1_id),
                sync_room_map,
                SlidingSyncConfig.SlidingSyncList.Filters(
                    is_encrypted=False,
                ),
                after_rooms_token,
            )
        )

        self.assertEqual(falsy_filtered_room_map.keys(), {room_id})

    def test_filter_encrypted_server_left_room(self) -> None:
        """
        Test that we can apply a `filter.is_encrypted` against a room that everyone has left.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")

        before_rooms_token = self.event_sources.get_current_token()

        # Create an unencrypted room, then leave it
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
        self.helper.leave(room_id, user1_id, tok=user1_tok)

        # Create an encrypted room
        encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
        self.helper.send_state(
            encrypted_room_id,
            EventTypes.RoomEncryption,
            {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
            tok=user1_tok,
        )
        # Leave the encrypted room too
        self.helper.leave(encrypted_room_id, user1_id, tok=user1_tok)

        after_rooms_token = self.event_sources.get_current_token()

        # Get the rooms the user should be syncing with
        sync_room_map = self._get_sync_room_ids_for_user(
            UserID.from_string(user1_id),
            # We're using a `from_token` so that the room is considered `newly_left` and
            # appears in our list of relevant sync rooms
            from_token=before_rooms_token,
            to_token=after_rooms_token,
        )

        # Try with `is_encrypted=True`
        truthy_filtered_room_map = self.get_success(
            self.sliding_sync_handler.filter_rooms(
                UserID.from_string(user1_id),
                sync_room_map,
                SlidingSyncConfig.SlidingSyncList.Filters(
                    is_encrypted=True,
                ),
                after_rooms_token,
            )
        )

        self.assertEqual(truthy_filtered_room_map.keys(), {encrypted_room_id})

        # Try with `is_encrypted=False`
        falsy_filtered_room_map = self.get_success(
            self.sliding_sync_handler.filter_rooms(
                UserID.from_string(user1_id),
                sync_room_map,
                SlidingSyncConfig.SlidingSyncList.Filters(
                    is_encrypted=False,
                ),
                after_rooms_token,
            )
        )

        self.assertEqual(falsy_filtered_room_map.keys(), {room_id})

    def test_filter_encrypted_server_left_room2(self) -> None:
        """
        Test that we can apply a `filter.is_encrypted` against a room that everyone has
        left.

        There is still someone local who is invited to the rooms but that doesn't affect
        whether the server is participating in the room (users need to be joined).
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        user2_id = self.register_user("user2", "pass")
        _user2_tok = self.login(user2_id, "pass")

        before_rooms_token = self.event_sources.get_current_token()

        # Create an unencrypted room
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
        # Invite user2
        self.helper.invite(room_id, targ=user2_id, tok=user1_tok)
        # User1 leaves the room
        self.helper.leave(room_id, user1_id, tok=user1_tok)

        # Create an encrypted room
        encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
        self.helper.send_state(
            encrypted_room_id,
            EventTypes.RoomEncryption,
            {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
            tok=user1_tok,
        )
        # Invite user2
        self.helper.invite(encrypted_room_id, targ=user2_id, tok=user1_tok)
        # User1 leaves the room
        self.helper.leave(encrypted_room_id, user1_id, tok=user1_tok)

        after_rooms_token = self.event_sources.get_current_token()

        # Get the rooms the user should be syncing with
        sync_room_map = self._get_sync_room_ids_for_user(
            UserID.from_string(user1_id),
            # We're using a `from_token` so that the room is considered `newly_left` and
            # appears in our list of relevant sync rooms
            from_token=before_rooms_token,
            to_token=after_rooms_token,
        )

        # Try with `is_encrypted=True`
        truthy_filtered_room_map = self.get_success(
            self.sliding_sync_handler.filter_rooms(
                UserID.from_string(user1_id),
                sync_room_map,
                SlidingSyncConfig.SlidingSyncList.Filters(
                    is_encrypted=True,
                ),
                after_rooms_token,
            )
        )

        self.assertEqual(truthy_filtered_room_map.keys(), {encrypted_room_id})

        # Try with `is_encrypted=False`
        falsy_filtered_room_map = self.get_success(
            self.sliding_sync_handler.filter_rooms(
                UserID.from_string(user1_id),
                sync_room_map,
                SlidingSyncConfig.SlidingSyncList.Filters(
                    is_encrypted=False,
                ),
                after_rooms_token,
            )
        )

        self.assertEqual(falsy_filtered_room_map.keys(), {room_id})

    def test_filter_encrypted_after_we_left(self) -> None:
        """
        Test that we can apply a `filter.is_encrypted` against a room that was encrypted
        after we left the room (make sure we don't just use the current state)
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        user2_id = self.register_user("user2", "pass")
        user2_tok = self.login(user2_id, "pass")

        before_rooms_token = self.event_sources.get_current_token()

        # Create an unencrypted room
        room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
        # Join and then leave the room
        self.helper.join(room_id, user1_id, tok=user1_tok)
        self.helper.leave(room_id, user1_id, tok=user1_tok)

        # Create a room that will be encrypted
        encrypted_after_we_left_room_id = self.helper.create_room_as(
            user2_id, tok=user2_tok
        )
        # Join and then leave the room
        self.helper.join(encrypted_after_we_left_room_id, user1_id, tok=user1_tok)
        self.helper.leave(encrypted_after_we_left_room_id, user1_id, tok=user1_tok)

        # Encrypt the room after we've left
        self.helper.send_state(
            encrypted_after_we_left_room_id,
            EventTypes.RoomEncryption,
            {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
            tok=user2_tok,
        )

        after_rooms_token = self.event_sources.get_current_token()

        # Get the rooms the user should be syncing with
        sync_room_map = self._get_sync_room_ids_for_user(
            UserID.from_string(user1_id),
            # We're using a `from_token` so that the room is considered `newly_left` and
            # appears in our list of relevant sync rooms
            from_token=before_rooms_token,
            to_token=after_rooms_token,
        )

        # Try with `is_encrypted=True`
        truthy_filtered_room_map = self.get_success(
            self.sliding_sync_handler.filter_rooms(
                UserID.from_string(user1_id),
                sync_room_map,
                SlidingSyncConfig.SlidingSyncList.Filters(
                    is_encrypted=True,
                ),
                after_rooms_token,
            )
        )

        # Even though we left the room before it was encrypted, we still see it because
        # someone else on our server is still participating in the room and we "leak"
        # the current state to the left user. But we consider the room encryption status
        # to not be a secret given it's often set at the start of the room and it's one
        # of the stripped state events that is normally handed out.
        self.assertEqual(
            truthy_filtered_room_map.keys(), {encrypted_after_we_left_room_id}
        )

        # Try with `is_encrypted=False`
        falsy_filtered_room_map = self.get_success(
            self.sliding_sync_handler.filter_rooms(
                UserID.from_string(user1_id),
                sync_room_map,
                SlidingSyncConfig.SlidingSyncList.Filters(
                    is_encrypted=False,
                ),
                after_rooms_token,
            )
        )

        # Even though we left the room before it was encrypted... (see comment above)
        self.assertEqual(falsy_filtered_room_map.keys(), {room_id})

    def test_filter_encrypted_with_remote_invite_room_no_stripped_state(self) -> None:
        """
        Test that we can apply a `filter.is_encrypted` filter against a remote invite
        room without any `unsigned.invite_room_state` (stripped state).
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")

        # Create a remote invite room without any `unsigned.invite_room_state`
        _remote_invite_room_id = self._create_remote_invite_room_for_user(
            user1_id, None
        )

        # Create an unencrypted room
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create an encrypted room
        encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
        self.helper.send_state(
            encrypted_room_id,
            EventTypes.RoomEncryption,
            {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
            tok=user1_tok,
        )

        after_rooms_token = self.event_sources.get_current_token()

        # Get the rooms the user should be syncing with
        sync_room_map = self._get_sync_room_ids_for_user(
            UserID.from_string(user1_id),
            from_token=None,
            to_token=after_rooms_token,
        )

        # Try with `is_encrypted=True`
        truthy_filtered_room_map = self.get_success(
            self.sliding_sync_handler.filter_rooms(
                UserID.from_string(user1_id),
                sync_room_map,
                SlidingSyncConfig.SlidingSyncList.Filters(
                    is_encrypted=True,
                ),
                after_rooms_token,
            )
        )

        # `remote_invite_room_id` should not appear because we can't figure out whether
        # it is encrypted or not (no stripped state, `unsigned.invite_room_state`).
        self.assertEqual(truthy_filtered_room_map.keys(), {encrypted_room_id})

        # Try with `is_encrypted=False`
        falsy_filtered_room_map = self.get_success(
            self.sliding_sync_handler.filter_rooms(
                UserID.from_string(user1_id),
                sync_room_map,
                SlidingSyncConfig.SlidingSyncList.Filters(
                    is_encrypted=False,
                ),
                after_rooms_token,
            )
        )

        # `remote_invite_room_id` should not appear because we can't figure out whether
        # it is encrypted or not (no stripped state, `unsigned.invite_room_state`).
        self.assertEqual(falsy_filtered_room_map.keys(), {room_id})

    def test_filter_encrypted_with_remote_invite_encrypted_room(self) -> None:
        """
        Test that we can apply a `filter.is_encrypted` filter against a remote invite
        encrypted room with some `unsigned.invite_room_state` (stripped state).
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")

        # Create a remote invite room with some `unsigned.invite_room_state`
        # indicating that the room is encrypted.
        remote_invite_room_id = self._create_remote_invite_room_for_user(
            user1_id,
            [
                StrippedStateEvent(
                    type=EventTypes.Create,
                    state_key="",
                    sender="@inviter:remote_server",
                    content={
                        EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
                        EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
                    },
                ),
                StrippedStateEvent(
                    type=EventTypes.RoomEncryption,
                    state_key="",
                    sender="@inviter:remote_server",
                    content={
                        EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2",
                    },
                ),
            ],
        )

        # Create an unencrypted room
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create an encrypted room
        encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
        self.helper.send_state(
            encrypted_room_id,
            EventTypes.RoomEncryption,
            {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
            tok=user1_tok,
        )

        after_rooms_token = self.event_sources.get_current_token()

        # Get the rooms the user should be syncing with
        sync_room_map = self._get_sync_room_ids_for_user(
            UserID.from_string(user1_id),
            from_token=None,
            to_token=after_rooms_token,
        )

        # Try with `is_encrypted=True`
        truthy_filtered_room_map = self.get_success(
            self.sliding_sync_handler.filter_rooms(
                UserID.from_string(user1_id),
                sync_room_map,
                SlidingSyncConfig.SlidingSyncList.Filters(
                    is_encrypted=True,
                ),
                after_rooms_token,
            )
        )

        # `remote_invite_room_id` should appear here because it is encrypted
        # according to the stripped state
        self.assertEqual(
            truthy_filtered_room_map.keys(), {encrypted_room_id, remote_invite_room_id}
        )

        # Try with `is_encrypted=False`
        falsy_filtered_room_map = self.get_success(
            self.sliding_sync_handler.filter_rooms(
                UserID.from_string(user1_id),
                sync_room_map,
                SlidingSyncConfig.SlidingSyncList.Filters(
                    is_encrypted=False,
                ),
                after_rooms_token,
            )
        )

        # `remote_invite_room_id` should not appear here because it is encrypted
        # according to the stripped state
        self.assertEqual(falsy_filtered_room_map.keys(), {room_id})

    def test_filter_encrypted_with_remote_invite_unencrypted_room(self) -> None:
        """
        Test that we can apply a `filter.is_encrypted` filter against a remote invite
        unencrypted room with some `unsigned.invite_room_state` (stripped state).
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")

        # Create a remote invite room with some `unsigned.invite_room_state`
        # but don't set any room encryption event.
        remote_invite_room_id = self._create_remote_invite_room_for_user(
            user1_id,
            [
                StrippedStateEvent(
                    type=EventTypes.Create,
                    state_key="",
                    sender="@inviter:remote_server",
                    content={
                        EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
                        EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
                    },
                ),
                # No room encryption event
            ],
        )

        # Create an unencrypted room
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create an encrypted room
        encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
        self.helper.send_state(
            encrypted_room_id,
            EventTypes.RoomEncryption,
            {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
            tok=user1_tok,
        )

        after_rooms_token = self.event_sources.get_current_token()

        # Get the rooms the user should be syncing with
        sync_room_map = self._get_sync_room_ids_for_user(
            UserID.from_string(user1_id),
            from_token=None,
            to_token=after_rooms_token,
        )

        # Try with `is_encrypted=True`
        truthy_filtered_room_map = self.get_success(
            self.sliding_sync_handler.filter_rooms(
                UserID.from_string(user1_id),
                sync_room_map,
                SlidingSyncConfig.SlidingSyncList.Filters(
                    is_encrypted=True,
                ),
                after_rooms_token,
            )
        )

        # `remote_invite_room_id` should not appear here because it is unencrypted
        # according to the stripped state
        self.assertEqual(truthy_filtered_room_map.keys(), {encrypted_room_id})

        # Try with `is_encrypted=False`
        falsy_filtered_room_map = self.get_success(
            self.sliding_sync_handler.filter_rooms(
                UserID.from_string(user1_id),
                sync_room_map,
                SlidingSyncConfig.SlidingSyncList.Filters(
                    is_encrypted=False,
                ),
                after_rooms_token,
            )
        )

        # `remote_invite_room_id` should appear because it is unencrypted according to
        # the stripped state
        self.assertEqual(
            falsy_filtered_room_map.keys(), {room_id, remote_invite_room_id}
        )

    def test_filter_invite_rooms(self) -> None:
        """
        Test `filter.is_invite`
        for rooms that the user has been invited to
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        user2_id = self.register_user("user2", "pass")
        user2_tok = self.login(user2_id, "pass")

        # Create a normal room that user1 is joined to
        room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
        self.helper.join(room_id, user1_id, tok=user1_tok)

        # Create a room that user1 is invited to
        invite_room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
        self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok)

        after_rooms_token = self.event_sources.get_current_token()

        # Get the rooms the user should be syncing with
        sync_room_map = self._get_sync_room_ids_for_user(
            UserID.from_string(user1_id),
            from_token=None,
            to_token=after_rooms_token,
        )

        # Try with `is_invite=True`
        truthy_filtered_room_map = self.get_success(
            self.sliding_sync_handler.filter_rooms(
                UserID.from_string(user1_id),
                sync_room_map,
                SlidingSyncConfig.SlidingSyncList.Filters(
                    is_invite=True,
                ),
                after_rooms_token,
            )
        )

        self.assertEqual(truthy_filtered_room_map.keys(), {invite_room_id})

        # Try with `is_invite=False`
        falsy_filtered_room_map = self.get_success(
            self.sliding_sync_handler.filter_rooms(
                UserID.from_string(user1_id),
                sync_room_map,
                SlidingSyncConfig.SlidingSyncList.Filters(
                    is_invite=False,
                ),
                after_rooms_token,
            )
        )

        self.assertEqual(falsy_filtered_room_map.keys(), {room_id})

    def test_filter_room_types(self) -> None:
        """
        Test `filter.room_types` for different room types
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")

        # Create a normal room (no room type)
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create a space room
        space_room_id = self.helper.create_room_as(
            user1_id,
            tok=user1_tok,
            extra_content={
                "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
            },
        )

        # Create an arbitrarily typed room
        foo_room_id = self.helper.create_room_as(
            user1_id,
            tok=user1_tok,
            extra_content={
                "creation_content": {
                    EventContentFields.ROOM_TYPE: "org.matrix.foobarbaz"
                }
            },
        )

        after_rooms_token = self.event_sources.get_current_token()

        # Get the rooms the user should be syncing with
        sync_room_map = self._get_sync_room_ids_for_user(
            UserID.from_string(user1_id),
            from_token=None,
            to_token=after_rooms_token,
        )

        # Try finding only normal rooms
        filtered_room_map = self.get_success(
            self.sliding_sync_handler.filter_rooms(
                UserID.from_string(user1_id),
                sync_room_map,
                SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]),
                after_rooms_token,
            )
        )

        self.assertEqual(filtered_room_map.keys(), {room_id})

        # Try finding only spaces
        filtered_room_map = self.get_success(
            self.sliding_sync_handler.filter_rooms(
                UserID.from_string(user1_id),
                sync_room_map,
                SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]),
                after_rooms_token,
            )
        )

        self.assertEqual(filtered_room_map.keys(), {space_room_id})

        # Try finding normal rooms and spaces
        filtered_room_map = self.get_success(
            self.sliding_sync_handler.filter_rooms(
                UserID.from_string(user1_id),
                sync_room_map,
                SlidingSyncConfig.SlidingSyncList.Filters(
                    room_types=[None, RoomTypes.SPACE]
                ),
                after_rooms_token,
            )
        )

        self.assertEqual(filtered_room_map.keys(), {room_id, space_room_id})

        # Try finding an arbitrary room type
        filtered_room_map = self.get_success(
            self.sliding_sync_handler.filter_rooms(
                UserID.from_string(user1_id),
                sync_room_map,
                SlidingSyncConfig.SlidingSyncList.Filters(
                    room_types=["org.matrix.foobarbaz"]
                ),
                after_rooms_token,
            )
        )

        self.assertEqual(filtered_room_map.keys(), {foo_room_id})

    def test_filter_not_room_types(self) -> None:
        """
        Test `filter.not_room_types` for different room types
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")

        # Create a normal room (no room type)
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create a space room
        space_room_id = self.helper.create_room_as(
            user1_id,
            tok=user1_tok,
            extra_content={
                "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
            },
        )

        # Create an arbitrarily typed room
        foo_room_id = self.helper.create_room_as(
            user1_id,
            tok=user1_tok,
            extra_content={
                "creation_content": {
                    EventContentFields.ROOM_TYPE: "org.matrix.foobarbaz"
                }
            },
        )

        after_rooms_token = self.event_sources.get_current_token()

        # Get the rooms the user should be syncing with
        sync_room_map = self._get_sync_room_ids_for_user(
            UserID.from_string(user1_id),
            from_token=None,
            to_token=after_rooms_token,
        )

        # Try finding *NOT* normal rooms
        filtered_room_map = self.get_success(
            self.sliding_sync_handler.filter_rooms(
                UserID.from_string(user1_id),
                sync_room_map,
                SlidingSyncConfig.SlidingSyncList.Filters(not_room_types=[None]),
                after_rooms_token,
            )
        )

        self.assertEqual(filtered_room_map.keys(), {space_room_id, foo_room_id})

        # Try finding *NOT* spaces
        filtered_room_map = self.get_success(
            self.sliding_sync_handler.filter_rooms(
                UserID.from_string(user1_id),
                sync_room_map,
                SlidingSyncConfig.SlidingSyncList.Filters(
                    not_room_types=[RoomTypes.SPACE]
                ),
                after_rooms_token,
            )
        )

        self.assertEqual(filtered_room_map.keys(), {room_id, foo_room_id})

        # Try finding *NOT* normal rooms or spaces
        filtered_room_map = self.get_success(
            self.sliding_sync_handler.filter_rooms(
                UserID.from_string(user1_id),
                sync_room_map,
                SlidingSyncConfig.SlidingSyncList.Filters(
                    not_room_types=[None, RoomTypes.SPACE]
                ),
                after_rooms_token,
            )
        )

        self.assertEqual(filtered_room_map.keys(), {foo_room_id})

        # Test how it behaves when we have both `room_types` and `not_room_types`.
        # `not_room_types` should win.
        filtered_room_map = self.get_success(
            self.sliding_sync_handler.filter_rooms(
                UserID.from_string(user1_id),
                sync_room_map,
                SlidingSyncConfig.SlidingSyncList.Filters(
                    room_types=[None], not_room_types=[None]
                ),
                after_rooms_token,
            )
        )

        # Nothing matches because nothing is both a normal room and not a normal room
        self.assertEqual(filtered_room_map.keys(), set())

        # Test how it behaves when we have both `room_types` and `not_room_types`.
        # `not_room_types` should win.
        filtered_room_map = self.get_success(
            self.sliding_sync_handler.filter_rooms(
                UserID.from_string(user1_id),
                sync_room_map,
                SlidingSyncConfig.SlidingSyncList.Filters(
                    room_types=[None, RoomTypes.SPACE], not_room_types=[None]
                ),
                after_rooms_token,
            )
        )

        self.assertEqual(filtered_room_map.keys(), {space_room_id})

    def test_filter_room_types_server_left_room(self) -> None:
        """
        Test that we can apply a `filter.room_types` against a room that everyone has left.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")

        before_rooms_token = self.event_sources.get_current_token()

        # Create a normal room (no room type)
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
        # Leave the room
        self.helper.leave(room_id, user1_id, tok=user1_tok)

        # Create a space room
        space_room_id = self.helper.create_room_as(
            user1_id,
            tok=user1_tok,
            extra_content={
                "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
            },
        )
        # Leave the room
        self.helper.leave(space_room_id, user1_id, tok=user1_tok)

        after_rooms_token = self.event_sources.get_current_token()

        # Get the rooms the user should be syncing with
        sync_room_map = self._get_sync_room_ids_for_user(
            UserID.from_string(user1_id),
            # We're using a `from_token` so that the room is considered `newly_left` and
            # appears in our list of relevant sync rooms
            from_token=before_rooms_token,
            to_token=after_rooms_token,
        )
- # Try finding only normal rooms - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]), - after_rooms_token, - ) - ) - - self.assertEqual(filtered_room_map.keys(), {room_id}) - - # Try finding only spaces - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]), - after_rooms_token, - ) - ) - - self.assertEqual(filtered_room_map.keys(), {space_room_id}) - - def test_filter_room_types_server_left_room2(self) -> None: - """ - Test that we can apply a `filter.room_types` against a room that everyone has left. - - There is still someone local who is invited to the rooms but that doesn't affect - whether the server is participating in the room (users need to be joined). - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - _user2_tok = self.login(user2_id, "pass") - - before_rooms_token = self.event_sources.get_current_token() - - # Create a normal room (no room type) - room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - # Invite user2 - self.helper.invite(room_id, targ=user2_id, tok=user1_tok) - # User1 leaves the room - self.helper.leave(room_id, user1_id, tok=user1_tok) - - # Create a space room - space_room_id = self.helper.create_room_as( - user1_id, - tok=user1_tok, - extra_content={ - "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} - }, - ) - # Invite user2 - self.helper.invite(space_room_id, targ=user2_id, tok=user1_tok) - # User1 leaves the room - self.helper.leave(space_room_id, user1_id, tok=user1_tok) - - after_rooms_token = self.event_sources.get_current_token() - - # Get the rooms the user should be syncing with - sync_room_map = 
self._get_sync_room_ids_for_user( - UserID.from_string(user1_id), - # We're using a `from_token` so that the room is considered `newly_left` and - # appears in our list of relevant sync rooms - from_token=before_rooms_token, - to_token=after_rooms_token, - ) - - # Try finding only normal rooms - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]), - after_rooms_token, - ) - ) - - self.assertEqual(filtered_room_map.keys(), {room_id}) - - # Try finding only spaces - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]), - after_rooms_token, - ) - ) - - self.assertEqual(filtered_room_map.keys(), {space_room_id}) - - def test_filter_room_types_with_remote_invite_room_no_stripped_state(self) -> None: - """ - Test that we can apply a `filter.room_types` filter against a remote invite - room without any `unsigned.invite_room_state` (stripped state). 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - # Create a remote invite room without any `unsigned.invite_room_state` - _remote_invite_room_id = self._create_remote_invite_room_for_user( - user1_id, None - ) - - # Create a normal room (no room type) - room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - - # Create a space room - space_room_id = self.helper.create_room_as( - user1_id, - tok=user1_tok, - extra_content={ - "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} - }, - ) - - after_rooms_token = self.event_sources.get_current_token() - - # Get the rooms the user should be syncing with - sync_room_map = self._get_sync_room_ids_for_user( - UserID.from_string(user1_id), - from_token=None, - to_token=after_rooms_token, - ) - - # Try finding only normal rooms - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]), - after_rooms_token, - ) - ) - - # `remote_invite_room_id` should not appear because we can't figure out what - # room type it is (no stripped state, `unsigned.invite_room_state`) - self.assertEqual(filtered_room_map.keys(), {room_id}) - - # Try finding only spaces - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]), - after_rooms_token, - ) - ) - - # `remote_invite_room_id` should not appear because we can't figure out what - # room type it is (no stripped state, `unsigned.invite_room_state`) - self.assertEqual(filtered_room_map.keys(), {space_room_id}) - - def test_filter_room_types_with_remote_invite_space(self) -> None: - """ - Test that we can apply a `filter.room_types` filter against a remote invite - to a space room with some `unsigned.invite_room_state` (stripped state). 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - # Create a remote invite room with some `unsigned.invite_room_state` indicating - # that it is a space room - remote_invite_room_id = self._create_remote_invite_room_for_user( - user1_id, - [ - StrippedStateEvent( - type=EventTypes.Create, - state_key="", - sender="@inviter:remote_server", - content={ - EventContentFields.ROOM_CREATOR: "@inviter:remote_server", - EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier, - # Specify that it is a space room - EventContentFields.ROOM_TYPE: RoomTypes.SPACE, - }, - ), - ], - ) - - # Create a normal room (no room type) - room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - - # Create a space room - space_room_id = self.helper.create_room_as( - user1_id, - tok=user1_tok, - extra_content={ - "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} - }, - ) - - after_rooms_token = self.event_sources.get_current_token() - - # Get the rooms the user should be syncing with - sync_room_map = self._get_sync_room_ids_for_user( - UserID.from_string(user1_id), - from_token=None, - to_token=after_rooms_token, - ) - - # Try finding only normal rooms - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]), - after_rooms_token, - ) - ) - - # `remote_invite_room_id` should not appear here because it is a space room - # according to the stripped state - self.assertEqual(filtered_room_map.keys(), {room_id}) - - # Try finding only spaces - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]), - after_rooms_token, - ) - ) - - # `remote_invite_room_id` should appear here because it is a space room - # according to the stripped state 
- self.assertEqual( - filtered_room_map.keys(), {space_room_id, remote_invite_room_id} - ) - - def test_filter_room_types_with_remote_invite_normal_room(self) -> None: - """ - Test that we can apply a `filter.room_types` filter against a remote invite - to a normal room with some `unsigned.invite_room_state` (stripped state). - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - # Create a remote invite room with some `unsigned.invite_room_state` - # but the create event does not specify a room type (normal room) - remote_invite_room_id = self._create_remote_invite_room_for_user( - user1_id, - [ - StrippedStateEvent( - type=EventTypes.Create, - state_key="", - sender="@inviter:remote_server", - content={ - EventContentFields.ROOM_CREATOR: "@inviter:remote_server", - EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier, - # No room type means this is a normal room - }, - ), - ], - ) - - # Create a normal room (no room type) - room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - - # Create a space room - space_room_id = self.helper.create_room_as( - user1_id, - tok=user1_tok, - extra_content={ - "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} - }, - ) - - after_rooms_token = self.event_sources.get_current_token() - - # Get the rooms the user should be syncing with - sync_room_map = self._get_sync_room_ids_for_user( - UserID.from_string(user1_id), - from_token=None, - to_token=after_rooms_token, - ) - - # Try finding only normal rooms - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]), - after_rooms_token, - ) - ) - - # `remote_invite_room_id` should appear here because it is a normal room - # according to the stripped state (no room type) - self.assertEqual(filtered_room_map.keys(), {room_id, remote_invite_room_id}) - - # Try finding only 
spaces - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]), - after_rooms_token, - ) - ) - - # `remote_invite_room_id` should not appear here because it is a normal room - # according to the stripped state (no room type) - self.assertEqual(filtered_room_map.keys(), {space_room_id}) - - -class SortRoomsTestCase(HomeserverTestCase): - """ - Tests Sliding Sync handler `sort_rooms()` to make sure it sorts/orders rooms - correctly. - """ - - servlets = [ - admin.register_servlets, - knock.register_servlets, - login.register_servlets, - room.register_servlets, - ] - - def default_config(self) -> JsonDict: - config = super().default_config() - # Enable sliding sync - config["experimental_features"] = {"msc3575_enabled": True} - return config - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.sliding_sync_handler = self.hs.get_sliding_sync_handler() - self.store = self.hs.get_datastores().main - self.event_sources = hs.get_event_sources() - - def _get_sync_room_ids_for_user( - self, - user: UserID, - to_token: StreamToken, - from_token: Optional[StreamToken], - ) -> Dict[str, _RoomMembershipForUser]: - """ - Get the rooms the user should be syncing with - """ - room_membership_for_user_map = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - user=user, - from_token=from_token, - to_token=to_token, - ) - ) - filtered_sync_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms_relevant_for_sync( - user=user, - room_membership_for_user_map=room_membership_for_user_map, - ) - ) - - return filtered_sync_room_map - - def test_sort_activity_basic(self) -> None: - """ - Rooms with newer activity are sorted first. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - room_id1 = self.helper.create_room_as( - user1_id, - tok=user1_tok, - ) - room_id2 = self.helper.create_room_as( - user1_id, - tok=user1_tok, - ) - - after_rooms_token = self.event_sources.get_current_token() - - # Get the rooms the user should be syncing with - sync_room_map = self._get_sync_room_ids_for_user( - UserID.from_string(user1_id), - from_token=None, - to_token=after_rooms_token, - ) - - # Sort the rooms (what we're testing) - sorted_sync_rooms = self.get_success( - self.sliding_sync_handler.sort_rooms( - sync_room_map=sync_room_map, - to_token=after_rooms_token, - ) - ) - - self.assertEqual( - [room_membership.room_id for room_membership in sorted_sync_rooms], - [room_id2, room_id1], - ) - - @parameterized.expand( - [ - (Membership.LEAVE,), - (Membership.INVITE,), - (Membership.KNOCK,), - (Membership.BAN,), - ] - ) - def test_activity_after_xxx(self, room1_membership: str) -> None: - """ - When someone has left/been invited/knocked/been banned from a room, they - shouldn't take anything into account after that membership event. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - before_rooms_token = self.event_sources.get_current_token() - - # Create the rooms as user2 so we can have user1 with a clean slate to work from - # and join in whatever order we need for the tests. 
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) - # If we're testing knocks, set the room to knock - if room1_membership == Membership.KNOCK: - self.helper.send_state( - room_id1, - EventTypes.JoinRules, - {"join_rule": JoinRules.KNOCK}, - tok=user2_tok, - ) - room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) - room_id3 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) - - # Here is the activity with user1 that will determine the sort of the rooms - # (room2, room1, room3) - self.helper.join(room_id3, user1_id, tok=user1_tok) - if room1_membership == Membership.LEAVE: - self.helper.join(room_id1, user1_id, tok=user1_tok) - self.helper.leave(room_id1, user1_id, tok=user1_tok) - elif room1_membership == Membership.INVITE: - self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok) - elif room1_membership == Membership.KNOCK: - self.helper.knock(room_id1, user1_id, tok=user1_tok) - elif room1_membership == Membership.BAN: - self.helper.ban(room_id1, src=user2_id, targ=user1_id, tok=user2_tok) - self.helper.join(room_id2, user1_id, tok=user1_tok) - - # Activity before the token but the user is only been xxx to this room so it - # shouldn't be taken into account - self.helper.send(room_id1, "activity in room1", tok=user2_tok) - - after_rooms_token = self.event_sources.get_current_token() - - # Activity after the token. Just make it in a different order than what we - # expect to make sure we're not taking the activity after the token into - # account. 
- self.helper.send(room_id1, "activity in room1", tok=user2_tok) - self.helper.send(room_id2, "activity in room2", tok=user2_tok) - self.helper.send(room_id3, "activity in room3", tok=user2_tok) - - # Get the rooms the user should be syncing with - sync_room_map = self._get_sync_room_ids_for_user( - UserID.from_string(user1_id), - from_token=before_rooms_token, - to_token=after_rooms_token, - ) - - # Sort the rooms (what we're testing) - sorted_sync_rooms = self.get_success( - self.sliding_sync_handler.sort_rooms( - sync_room_map=sync_room_map, - to_token=after_rooms_token, - ) - ) - - self.assertEqual( - [room_membership.room_id for room_membership in sorted_sync_rooms], - [room_id2, room_id1, room_id3], - "Corresponding map to disambiguate the opaque room IDs: " - + str( - { - "room_id1": room_id1, - "room_id2": room_id2, - "room_id3": room_id3, - } - ), - ) - - def test_default_bump_event_types(self) -> None: - """ - Test that we only consider the *latest* event in the room when sorting (not - `bump_event_types`). - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - room_id1 = self.helper.create_room_as( - user1_id, - tok=user1_tok, - ) - message_response = self.helper.send(room_id1, "message in room1", tok=user1_tok) - room_id2 = self.helper.create_room_as( - user1_id, - tok=user1_tok, - ) - self.helper.send(room_id2, "message in room2", tok=user1_tok) - - # Send a reaction in room1 which isn't in `DEFAULT_BUMP_EVENT_TYPES` but we only - # care about sorting by the *latest* event in the room. 
- self.helper.send_event( - room_id1, - type=EventTypes.Reaction, - content={ - "m.relates_to": { - "event_id": message_response["event_id"], - "key": "👍", - "rel_type": "m.annotation", - } - }, - tok=user1_tok, - ) - - after_rooms_token = self.event_sources.get_current_token() - - # Get the rooms the user should be syncing with - sync_room_map = self._get_sync_room_ids_for_user( - UserID.from_string(user1_id), - from_token=None, - to_token=after_rooms_token, - ) - - # Sort the rooms (what we're testing) - sorted_sync_rooms = self.get_success( - self.sliding_sync_handler.sort_rooms( - sync_room_map=sync_room_map, - to_token=after_rooms_token, - ) - ) - - self.assertEqual( - [room_membership.room_id for room_membership in sorted_sync_rooms], - # room1 sorts before room2 because it has the latest event (the reaction). - # We only care about the *latest* event in the room. - [room_id1, room_id2], - ) diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index fa55f76916..37904926e3 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py
@@ -17,46 +17,25 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Collection, ContextManager, List, Optional +from typing import Optional from unittest.mock import AsyncMock, Mock, patch -from parameterized import parameterized - -from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor -from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules +from synapse.api.constants import EventTypes, JoinRules from synapse.api.errors import Codes, ResourceLimitError -from synapse.api.filtering import FilterCollection, Filtering -from synapse.api.room_versions import RoomVersion, RoomVersions -from synapse.events import EventBase -from synapse.events.snapshot import EventContext -from synapse.federation.federation_base import event_from_pdu_json -from synapse.handlers.sync import SyncConfig, SyncRequestKey, SyncResult, SyncVersion +from synapse.api.filtering import Filtering +from synapse.api.room_versions import RoomVersions +from synapse.handlers.sync import SyncConfig, SyncResult from synapse.rest import admin from synapse.rest.client import knock, login, room from synapse.server import HomeServer -from synapse.types import ( - JsonDict, - MultiWriterStreamToken, - RoomStreamToken, - StreamKeyType, - UserID, - create_requester, -) +from synapse.types import UserID, create_requester from synapse.util import Clock import tests.unittest import tests.utils -_request_key = 0 - - -def generate_request_key() -> SyncRequestKey: - global _request_key - _request_key += 1 - return ("request_key", _request_key) - class SyncTestCase(tests.unittest.HomeserverTestCase): """Tests Sync Handler.""" @@ -89,23 +68,13 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # Check that the happy case does not throw errors self.get_success(self.store.upsert_monthly_active_user(user_id1)) self.get_success( - self.sync_handler.wait_for_sync_for_user( - requester, - sync_config, - 
sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), - ) + self.sync_handler.wait_for_sync_for_user(requester, sync_config) ) # Test that global lock works self.auth_blocking._hs_disabled = True e = self.get_failure( - self.sync_handler.wait_for_sync_for_user( - requester, - sync_config, - sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), - ), + self.sync_handler.wait_for_sync_for_user(requester, sync_config), ResourceLimitError, ) self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) @@ -116,12 +85,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): requester = create_requester(user_id2) e = self.get_failure( - self.sync_handler.wait_for_sync_for_user( - requester, - sync_config, - sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), - ), + self.sync_handler.wait_for_sync_for_user(requester, sync_config), ResourceLimitError, ) self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) @@ -140,10 +104,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): requester = create_requester(user) initial_result = self.get_success( self.sync_handler.wait_for_sync_for_user( - requester, - sync_config=generate_sync_config(user, device_id="dev"), - sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), + requester, sync_config=generate_sync_config(user, device_id="dev") ) ) @@ -174,10 +135,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # The rooms should appear in the sync response. 
result = self.get_success( self.sync_handler.wait_for_sync_for_user( - requester, - sync_config=generate_sync_config(user), - sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), + requester, sync_config=generate_sync_config(user) ) ) self.assertIn(joined_room, [r.room_id for r in result.joined]) @@ -189,8 +147,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): self.sync_handler.wait_for_sync_for_user( requester, sync_config=generate_sync_config(user, device_id="dev"), - sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), since_token=initial_result.next_batch, ) ) @@ -210,8 +166,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): ) # Blow away caches (supported room versions can only change due to a restart). + self.store.get_rooms_for_user_with_stream_ordering.invalidate_all() self.store.get_rooms_for_user.invalidate_all() - self.store._get_rooms_for_local_user_where_membership_is_inner.invalidate_all() self.store._get_event_cache.clear() self.store._event_ref.clear() @@ -219,10 +175,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # Get a new request key. result = self.get_success( self.sync_handler.wait_for_sync_for_user( - requester, - sync_config=generate_sync_config(user), - sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), + requester, sync_config=generate_sync_config(user) ) ) self.assertNotIn(joined_room, [r.room_id for r in result.joined]) @@ -234,8 +187,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): self.sync_handler.wait_for_sync_for_user( requester, sync_config=generate_sync_config(user, device_id="dev"), - sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), since_token=initial_result.next_batch, ) ) @@ -275,10 +226,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # Do a sync as Alice to get the latest event in the room. 
alice_sync_result: SyncResult = self.get_success( self.sync_handler.wait_for_sync_for_user( - create_requester(owner), - generate_sync_config(owner), - sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), + create_requester(owner), generate_sync_config(owner) ) ) self.assertEqual(len(alice_sync_result.joined), 1) @@ -298,12 +246,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): eve_requester = create_requester(eve) eve_sync_config = generate_sync_config(eve) eve_sync_after_ban: SyncResult = self.get_success( - self.sync_handler.wait_for_sync_for_user( - eve_requester, - eve_sync_config, - sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), - ) + self.sync_handler.wait_for_sync_for_user(eve_requester, eve_sync_config) ) # Sanity check this sync result. We shouldn't be joined to the room. @@ -312,7 +255,13 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # Eve tries to join the room. We monkey patch the internal logic which selects # the prev_events used when creating the join event, such that the ban does not # precede the join. - with self._patch_get_latest_events([last_room_creation_event_id]): + mocked_get_prev_events = patch.object( + self.hs.get_datastores().main, + "get_prev_events_for_room", + new_callable=AsyncMock, + return_value=[last_room_creation_event_id], + ) + with mocked_get_prev_events: self.helper.join(room_id, eve, tok=eve_token) # Eve makes a second, incremental sync. 
@@ -320,8 +269,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): self.sync_handler.wait_for_sync_for_user( eve_requester, eve_sync_config, - sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), since_token=eve_sync_after_ban.next_batch, ) ) @@ -333,748 +280,25 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): self.sync_handler.wait_for_sync_for_user( eve_requester, eve_sync_config, - sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), since_token=None, ) ) self.assertEqual(eve_initial_sync_after_join.joined, []) - def test_state_includes_changes_on_forks(self) -> None: - """State changes that happen on a fork of the DAG must be included in `state` - - Given the following DAG: - - E1 - ↗ ↖ - | S2 - | ↑ - --|------|---- - | | - E3 | - ↖ / - E4 - - ... and a filter that means we only return 2 events, represented by the dashed - horizontal line: `S2` must be included in the `state` section. - """ - alice = self.register_user("alice", "password") - alice_tok = self.login(alice, "password") - alice_requester = create_requester(alice) - room_id = self.helper.create_room_as(alice, is_public=True, tok=alice_tok) - - # Do an initial sync as Alice to get a known starting point. 
- initial_sync_result = self.get_success( - self.sync_handler.wait_for_sync_for_user( - alice_requester, - generate_sync_config(alice), - sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), - ) - ) - last_room_creation_event_id = ( - initial_sync_result.joined[0].timeline.events[-1].event_id - ) - - # Send a state event, and a regular event, both using the same prev ID - with self._patch_get_latest_events([last_room_creation_event_id]): - s2_event = self.helper.send_state(room_id, "s2", {}, tok=alice_tok)[ - "event_id" - ] - e3_event = self.helper.send(room_id, "e3", tok=alice_tok)["event_id"] - - # Send a final event, joining the two branches of the dag - e4_event = self.helper.send(room_id, "e4", tok=alice_tok)["event_id"] - - # do an incremental sync, with a filter that will ensure we only get two of - # the three new events. - incremental_sync = self.get_success( - self.sync_handler.wait_for_sync_for_user( - alice_requester, - generate_sync_config( - alice, - filter_collection=FilterCollection( - self.hs, {"room": {"timeline": {"limit": 2}}} - ), - ), - sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), - since_token=initial_sync_result.next_batch, - ) - ) - - # The state event should appear in the 'state' section of the response. - room_sync = incremental_sync.joined[0] - self.assertEqual(room_sync.room_id, room_id) - self.assertTrue(room_sync.timeline.limited) - self.assertEqual( - [e.event_id for e in room_sync.timeline.events], - [e3_event, e4_event], - ) - self.assertEqual( - [e.event_id for e in room_sync.state.values()], - [s2_event], - ) - - def test_state_includes_changes_on_forks_when_events_excluded(self) -> None: - """A variation on the previous test, but where one event is filtered - - The DAG is the same as the previous test, but E4 is excluded by the filter. 
- - E1 - ↗ ↖ - | S2 - | ↑ - --|------|---- - | | - E3 | - ↖ / - (E4) - - """ - - alice = self.register_user("alice", "password") - alice_tok = self.login(alice, "password") - alice_requester = create_requester(alice) - room_id = self.helper.create_room_as(alice, is_public=True, tok=alice_tok) - - # Do an initial sync as Alice to get a known starting point. - initial_sync_result = self.get_success( - self.sync_handler.wait_for_sync_for_user( - alice_requester, - generate_sync_config(alice), - sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), - ) - ) - last_room_creation_event_id = ( - initial_sync_result.joined[0].timeline.events[-1].event_id - ) - - # Send a state event, and a regular event, both using the same prev ID - with self._patch_get_latest_events([last_room_creation_event_id]): - s2_event = self.helper.send_state(room_id, "s2", {}, tok=alice_tok)[ - "event_id" - ] - e3_event = self.helper.send(room_id, "e3", tok=alice_tok)["event_id"] - - # Send a final event, joining the two branches of the dag - self.helper.send(room_id, "e4", type="not_a_normal_message", tok=alice_tok)[ - "event_id" - ] - - # do an incremental sync, with a filter that will only return E3, excluding S2 - # and E4. - incremental_sync = self.get_success( - self.sync_handler.wait_for_sync_for_user( - alice_requester, - generate_sync_config( - alice, - filter_collection=FilterCollection( - self.hs, - { - "room": { - "timeline": { - "limit": 1, - "not_types": ["not_a_normal_message"], - } - } - }, - ), - ), - sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), - since_token=initial_sync_result.next_batch, - ) - ) - - # The state event should appear in the 'state' section of the response. 
- room_sync = incremental_sync.joined[0] - self.assertEqual(room_sync.room_id, room_id) - self.assertTrue(room_sync.timeline.limited) - self.assertEqual( - [e.event_id for e in room_sync.timeline.events], - [e3_event], - ) - self.assertEqual( - [e.event_id for e in room_sync.state.values()], - [s2_event], - ) - - def test_state_includes_changes_on_long_lived_forks(self) -> None: - """State changes that happen on a fork of the DAG must be included in `state` - - Given the following DAG: - - E1 - ↗ ↖ - | S2 - | ↑ - --|------|---- - E3 | - --|------|---- - | E4 - | | - - ... and a filter that means we only return 1 event, represented by the dashed - horizontal lines: `S2` must be included in the `state` section on the second sync. - """ - alice = self.register_user("alice", "password") - alice_tok = self.login(alice, "password") - alice_requester = create_requester(alice) - room_id = self.helper.create_room_as(alice, is_public=True, tok=alice_tok) - - # Do an initial sync as Alice to get a known starting point. - initial_sync_result = self.get_success( - self.sync_handler.wait_for_sync_for_user( - alice_requester, - generate_sync_config(alice), - sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), - ) - ) - last_room_creation_event_id = ( - initial_sync_result.joined[0].timeline.events[-1].event_id - ) - - # Send a state event, and a regular event, both using the same prev ID - with self._patch_get_latest_events([last_room_creation_event_id]): - s2_event = self.helper.send_state(room_id, "s2", {}, tok=alice_tok)[ - "event_id" - ] - e3_event = self.helper.send(room_id, "e3", tok=alice_tok)["event_id"] - - # Do an incremental sync, this will return E3 but *not* S2 at this - # point. 
- incremental_sync = self.get_success( - self.sync_handler.wait_for_sync_for_user( - alice_requester, - generate_sync_config( - alice, - filter_collection=FilterCollection( - self.hs, {"room": {"timeline": {"limit": 1}}} - ), - ), - sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), - since_token=initial_sync_result.next_batch, - ) - ) - room_sync = incremental_sync.joined[0] - self.assertEqual(room_sync.room_id, room_id) - self.assertTrue(room_sync.timeline.limited) - self.assertEqual( - [e.event_id for e in room_sync.timeline.events], - [e3_event], - ) - self.assertEqual( - [e.event_id for e in room_sync.state.values()], - [], - ) - - # Now send another event that points to S2, but not E3. - with self._patch_get_latest_events([s2_event]): - e4_event = self.helper.send(room_id, "e4", tok=alice_tok)["event_id"] - - # Doing an incremental sync should return S2 in state. - incremental_sync = self.get_success( - self.sync_handler.wait_for_sync_for_user( - alice_requester, - generate_sync_config( - alice, - filter_collection=FilterCollection( - self.hs, {"room": {"timeline": {"limit": 1}}} - ), - ), - sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), - since_token=incremental_sync.next_batch, - ) - ) - room_sync = incremental_sync.joined[0] - self.assertEqual(room_sync.room_id, room_id) - self.assertFalse(room_sync.timeline.limited) - self.assertEqual( - [e.event_id for e in room_sync.timeline.events], - [e4_event], - ) - self.assertEqual( - [e.event_id for e in room_sync.state.values()], - [s2_event], - ) - - def test_state_includes_changes_on_ungappy_syncs(self) -> None: - """Test `state` where the sync is not gappy. - - We start with a DAG like this: - - E1 - ↗ ↖ - | S2 - | - --|--- - | - E3 - - ... and initialsync with `limit=1`, represented by the horizontal dashed line. 
- At this point, we do not expect S2 to appear in the response at all (since - it is excluded from the timeline by the `limit`, and the state is based on the - state after the most recent event before the sync token (E3), which doesn't - include S2. - - Now more events arrive, and we do an incremental sync: - - E1 - ↗ ↖ - | S2 - | ↑ - E3 | - ↑ | - --|------|---- - | | - E4 | - ↖ / - E5 - - This is the last chance for us to tell the client about S2, so it *must* be - included in the response. - """ - alice = self.register_user("alice", "password") - alice_tok = self.login(alice, "password") - alice_requester = create_requester(alice) - room_id = self.helper.create_room_as(alice, is_public=True, tok=alice_tok) - - # Do an initial sync to get a known starting point. - initial_sync_result = self.get_success( - self.sync_handler.wait_for_sync_for_user( - alice_requester, - generate_sync_config(alice), - sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), - ) - ) - last_room_creation_event_id = ( - initial_sync_result.joined[0].timeline.events[-1].event_id - ) - - # Send a state event, and a regular event, both using the same prev ID - with self._patch_get_latest_events([last_room_creation_event_id]): - s2_event = self.helper.send_state(room_id, "s2", {}, tok=alice_tok)[ - "event_id" - ] - e3_event = self.helper.send(room_id, "e3", tok=alice_tok)["event_id"] - - # Another initial sync, with limit=1 - initial_sync_result = self.get_success( - self.sync_handler.wait_for_sync_for_user( - alice_requester, - generate_sync_config( - alice, - filter_collection=FilterCollection( - self.hs, {"room": {"timeline": {"limit": 1}}} - ), - ), - sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), - ) - ) - room_sync = initial_sync_result.joined[0] - self.assertEqual(room_sync.room_id, room_id) - self.assertEqual( - [e.event_id for e in room_sync.timeline.events], - [e3_event], - ) - self.assertNotIn(s2_event, [e.event_id for e in 
room_sync.state.values()]) - - # More events, E4 and E5 - with self._patch_get_latest_events([e3_event]): - e4_event = self.helper.send(room_id, "e4", tok=alice_tok)["event_id"] - e5_event = self.helper.send(room_id, "e5", tok=alice_tok)["event_id"] - - # Now incremental sync - incremental_sync = self.get_success( - self.sync_handler.wait_for_sync_for_user( - alice_requester, - generate_sync_config(alice), - sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), - since_token=initial_sync_result.next_batch, - ) - ) - - # The state event should appear in the 'state' section of the response. - room_sync = incremental_sync.joined[0] - self.assertEqual(room_sync.room_id, room_id) - self.assertFalse(room_sync.timeline.limited) - self.assertEqual( - [e.event_id for e in room_sync.timeline.events], - [e4_event, e5_event], - ) - self.assertEqual( - [e.event_id for e in room_sync.state.values()], - [s2_event], - ) - - @parameterized.expand( - [ - (False, False), - (True, False), - (False, True), - (True, True), - ] - ) - def test_archived_rooms_do_not_include_state_after_leave( - self, initial_sync: bool, empty_timeline: bool - ) -> None: - """If the user leaves the room, state changes that happen after they leave are not returned. - - We try with both a zero and a normal timeline limit, - and we try both an initial sync and an incremental sync for both. - """ - if empty_timeline and not initial_sync: - # FIXME synapse doesn't return the room at all in this situation! - self.skipTest("Synapse does not correctly handle this case") - - # Alice creates the room, and bob joins. 
- alice = self.register_user("alice", "password") - alice_tok = self.login(alice, "password") - - bob = self.register_user("bob", "password") - bob_tok = self.login(bob, "password") - bob_requester = create_requester(bob) - - room_id = self.helper.create_room_as(alice, is_public=True, tok=alice_tok) - self.helper.join(room_id, bob, tok=bob_tok) - - initial_sync_result = self.get_success( - self.sync_handler.wait_for_sync_for_user( - bob_requester, - generate_sync_config(bob), - sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), - ) - ) - - # Alice sends a message and a state - before_message_event = self.helper.send(room_id, "before", tok=alice_tok)[ - "event_id" - ] - before_state_event = self.helper.send_state( - room_id, "test_state", {"body": "before"}, tok=alice_tok - )["event_id"] - - # Bob leaves - leave_event = self.helper.leave(room_id, bob, tok=bob_tok)["event_id"] - - # Alice sends some more stuff - self.helper.send(room_id, "after", tok=alice_tok)["event_id"] - self.helper.send_state(room_id, "test_state", {"body": "after"}, tok=alice_tok)[ - "event_id" - ] - - # And now, Bob resyncs. - filter_dict: JsonDict = {"room": {"include_leave": True}} - if empty_timeline: - filter_dict["room"]["timeline"] = {"limit": 0} - sync_room_result = self.get_success( - self.sync_handler.wait_for_sync_for_user( - bob_requester, - generate_sync_config( - bob, filter_collection=FilterCollection(self.hs, filter_dict) - ), - sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), - since_token=None if initial_sync else initial_sync_result.next_batch, - ) - ).archived[0] - - if empty_timeline: - # The timeline should be empty - self.assertEqual(sync_room_result.timeline.events, []) - - # And the state should include the leave event... - self.assertEqual( - sync_room_result.state[("m.room.member", bob)].event_id, leave_event - ) - # ... and the state change before he left. 
- self.assertEqual( - sync_room_result.state[("test_state", "")].event_id, before_state_event - ) - else: - # The last three events in the timeline should be those leading up to the - # leave - self.assertEqual( - [e.event_id for e in sync_room_result.timeline.events[-3:]], - [before_message_event, before_state_event, leave_event], - ) - # ... And the state should be empty - self.assertEqual(sync_room_result.state, {}) - - def _patch_get_latest_events(self, latest_events: List[str]) -> ContextManager: - """Monkey-patch `get_prev_events_for_room` - - Returns a context manager which will replace the implementation of - `get_prev_events_for_room` with one which returns `latest_events`. - """ - return patch.object( - self.hs.get_datastores().main, - "get_prev_events_for_room", - new_callable=AsyncMock, - return_value=latest_events, - ) - - def test_call_invite_in_public_room_not_returned(self) -> None: - user = self.register_user("alice", "password") - tok = self.login(user, "password") - room_id = self.helper.create_room_as(user, is_public=True, tok=tok) - self.handler = self.hs.get_federation_handler() - federation_event_handler = self.hs.get_federation_event_handler() - - async def _check_event_auth( - origin: Optional[str], event: EventBase, context: EventContext - ) -> None: - pass - - federation_event_handler._check_event_auth = _check_event_auth # type: ignore[method-assign] - self.client = self.hs.get_federation_client() - - async def _check_sigs_and_hash_for_pulled_events_and_fetch( - dest: str, pdus: Collection[EventBase], room_version: RoomVersion - ) -> List[EventBase]: - return list(pdus) - - self.client._check_sigs_and_hash_for_pulled_events_and_fetch = _check_sigs_and_hash_for_pulled_events_and_fetch # type: ignore[assignment] - - prev_events = self.get_success(self.store.get_prev_events_for_room(room_id)) - - # create a call invite event - call_event = event_from_pdu_json( - { - "type": EventTypes.CallInvite, - "content": {}, - "room_id": room_id, - 
"sender": user, - "depth": 32, - "prev_events": prev_events, - "auth_events": prev_events, - "origin_server_ts": self.clock.time_msec(), - }, - RoomVersions.V10, - ) - - self.assertEqual( - self.get_success( - federation_event_handler.on_receive_pdu("test.serv", call_event) - ), - None, - ) - - # check that it is in DB - recent_event = self.get_success(self.store.get_prev_events_for_room(room_id)) - self.assertIn(call_event.event_id, recent_event) - - # but that it does not come down /sync in public room - sync_result: SyncResult = self.get_success( - self.sync_handler.wait_for_sync_for_user( - create_requester(user), - generate_sync_config(user), - sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), - ) - ) - event_ids = [] - for event in sync_result.joined[0].timeline.events: - event_ids.append(event.event_id) - self.assertNotIn(call_event.event_id, event_ids) - - # it will come down in a private room, though - user2 = self.register_user("bob", "password") - tok2 = self.login(user2, "password") - private_room_id = self.helper.create_room_as( - user2, is_public=False, tok=tok2, extra_content={"preset": "private_chat"} - ) - - priv_prev_events = self.get_success( - self.store.get_prev_events_for_room(private_room_id) - ) - private_call_event = event_from_pdu_json( - { - "type": EventTypes.CallInvite, - "content": {}, - "room_id": private_room_id, - "sender": user, - "depth": 32, - "prev_events": priv_prev_events, - "auth_events": priv_prev_events, - "origin_server_ts": self.clock.time_msec(), - }, - RoomVersions.V10, - ) - - self.assertEqual( - self.get_success( - federation_event_handler.on_receive_pdu("test.serv", private_call_event) - ), - None, - ) - - recent_events = self.get_success( - self.store.get_prev_events_for_room(private_room_id) - ) - self.assertIn(private_call_event.event_id, recent_events) - - private_sync_result: SyncResult = self.get_success( - self.sync_handler.wait_for_sync_for_user( - create_requester(user2), - 
generate_sync_config(user2), - sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), - ) - ) - priv_event_ids = [] - for event in private_sync_result.joined[0].timeline.events: - priv_event_ids.append(event.event_id) - - self.assertIn(private_call_event.event_id, priv_event_ids) - def test_push_rules_with_bad_account_data(self) -> None: - """Some old accounts have managed to set a `m.push_rules` account data, - which we should ignore in /sync response. - """ - - user = self.register_user("alice", "password") - - # Insert the bad account data. - self.get_success( - self.store.add_account_data_for_user(user, AccountDataTypes.PUSH_RULES, {}) - ) - - sync_result: SyncResult = self.get_success( - self.sync_handler.wait_for_sync_for_user( - create_requester(user), - generate_sync_config(user), - sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), - ) - ) - - for account_dict in sync_result.account_data: - if account_dict["type"] == AccountDataTypes.PUSH_RULES: - # We should have lots of push rules here, rather than the bad - # empty data. - self.assertNotEqual(account_dict["content"], {}) - return - - self.fail("No push rules found") - - def test_wait_for_future_sync_token(self) -> None: - """Test that if we receive a token that is ahead of our current token, - we'll wait until the stream position advances. - - This can happen if replication streams start lagging, and the client's - previous sync request was serviced by a worker ahead of ours. - """ - user = self.register_user("alice", "password") - - # We simulate a lagging stream by getting a stream ID from the ID gen - # and then waiting to mark it as "persisted". - presence_id_gen = self.store.get_presence_stream_id_gen() - ctx_mgr = presence_id_gen.get_next() - stream_id = self.get_success(ctx_mgr.__aenter__()) - - # Create the new token based on the stream ID above. 
- current_token = self.hs.get_event_sources().get_current_token() - since_token = current_token.copy_and_advance(StreamKeyType.PRESENCE, stream_id) - - sync_d = defer.ensureDeferred( - self.sync_handler.wait_for_sync_for_user( - create_requester(user), - generate_sync_config(user), - sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), - since_token=since_token, - timeout=0, - ) - ) - - # This should block waiting for the presence stream to update - self.pump() - self.assertFalse(sync_d.called) - - # Marking the stream ID as persisted should unblock the request. - self.get_success(ctx_mgr.__aexit__(None, None, None)) - - self.get_success(sync_d, by=1.0) - - @parameterized.expand( - [(key,) for key in StreamKeyType.__members__.values()], - name_func=lambda func, _, param: f"{func.__name__}_{param.args[0].name}", - ) - def test_wait_for_invalid_future_sync_token( - self, stream_key: StreamKeyType - ) -> None: - """Like the previous test, except we give a token that has a stream - position ahead of what is in the DB, i.e. its invalid and we shouldn't - wait for the stream to advance (as it may never do so). - - This can happen due to older versions of Synapse giving out stream - positions without persisting them in the DB, and so on restart the - stream would get reset back to an older position. - """ - user = self.register_user("alice", "password") - - # Create a token and advance one of the streams. - current_token = self.hs.get_event_sources().get_current_token() - token_value = current_token.get_field(stream_key) - - # How we advance the streams depends on the type. 
- if isinstance(token_value, int): - since_token = current_token.copy_and_advance(stream_key, token_value + 1) - elif isinstance(token_value, MultiWriterStreamToken): - since_token = current_token.copy_and_advance( - stream_key, MultiWriterStreamToken(stream=token_value.stream + 1) - ) - elif isinstance(token_value, RoomStreamToken): - since_token = current_token.copy_and_advance( - stream_key, RoomStreamToken(stream=token_value.stream + 1) - ) - else: - raise Exception("Unreachable") - - sync_d = defer.ensureDeferred( - self.sync_handler.wait_for_sync_for_user( - create_requester(user), - generate_sync_config(user), - sync_version=SyncVersion.SYNC_V2, - request_key=generate_request_key(), - since_token=since_token, - timeout=0, - ) - ) - - # We should return without waiting for the presence stream to advance. - self.get_success(sync_d) +_request_key = 0 def generate_sync_config( - user_id: str, - device_id: Optional[str] = "device_id", - filter_collection: Optional[FilterCollection] = None, + user_id: str, device_id: Optional[str] = "device_id" ) -> SyncConfig: - """Generate a sync config (with a unique request key). - - Args: - user_id: user who is syncing. - device_id: device that is syncing. Defaults to "device_id". - filter_collection: filter to apply. Defaults to the default filter (ie, - return everything, with a default limit) - """ - if filter_collection is None: - filter_collection = Filtering(Mock()).DEFAULT_FILTER_COLLECTION - + """Generate a sync config (with a unique request key).""" + global _request_key + _request_key += 1 return SyncConfig( user=UserID.from_string(user_id), - filter_collection=filter_collection, + filter_collection=Filtering(Mock()).DEFAULT_FILTER_COLLECTION, is_guest=False, + request_key=("request_key", _request_key), device_id=device_id, ) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 9d8960315f..c410fe7cc4 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2014-2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -32,7 +31,7 @@ from twisted.web.resource import Resource from synapse.api.constants import EduTypes from synapse.api.errors import AuthError from synapse.federation.transport.server import TransportLayerServer -from synapse.handlers.typing import FORGET_TIMEOUT, TypingWriterHandler +from synapse.handlers.typing import TypingWriterHandler from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.server import HomeServer from synapse.types import JsonDict, Requester, StreamKeyType, UserID, create_requester @@ -501,54 +500,3 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): } ], ) - - def test_prune_typing_replication(self) -> None: - """Regression test for `get_all_typing_updates` breaking when we prune - old updates - """ - self.room_members = [U_APPLE, U_BANANA] - - instance_name = self.hs.get_instance_name() - - self.get_success( - self.handler.started_typing( - target_user=U_APPLE, - requester=create_requester(U_APPLE), - room_id=ROOM_ID, - timeout=10000, - ) - ) - - rows, _, _ = self.get_success( - self.handler.get_all_typing_updates( - instance_name=instance_name, - last_id=0, - current_id=self.handler.get_current_token(), - limit=100, - ) - ) - self.assertEqual(rows, [(1, [ROOM_ID, [U_APPLE.to_string()]])]) - - self.reactor.advance(20000) - - rows, _, _ = self.get_success( - self.handler.get_all_typing_updates( - instance_name=instance_name, - last_id=1, - current_id=self.handler.get_current_token(), - limit=100, - ) - ) - self.assertEqual(rows, [(2, [ROOM_ID, []])]) - - self.reactor.advance(FORGET_TIMEOUT) - - rows, _, _ = self.get_success( - self.handler.get_all_typing_updates( - instance_name=instance_name, - last_id=1, - 
current_id=self.handler.get_current_token(), - limit=100, - ) - ) - self.assertEqual(rows, []) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 878d9683b6..77c6cac449 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py
@@ -1061,45 +1061,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): {alice: ProfileInfo(display_name=None, avatar_url=MXC_DUMMY)}, ) - def test_search_punctuation(self) -> None: - """Test that you can search for a user that includes punctuation""" - - searching_user = self.register_user("searcher", "password") - searching_user_tok = self.login("searcher", "password") - - room_id = self.helper.create_room_as( - searching_user, - room_version=RoomVersions.V1.identifier, - tok=searching_user_tok, - ) - - # We want to test searching for users of the form e.g. "user-1", with - # various punctuation. We also test both where the prefix is numeric and - # alphanumeric, as e.g. postgres tokenises "user-1" as "user" and "-1". - i = 1 - for char in ["-", ".", "_"]: - for use_numeric in [False, True]: - if use_numeric: - prefix1 = f"{i}" - prefix2 = f"{i+1}" - else: - prefix1 = f"a{i}" - prefix2 = f"a{i+1}" - - local_user_1 = self.register_user(f"user{char}{prefix1}", "password") - local_user_2 = self.register_user(f"user{char}{prefix2}", "password") - - self._add_user_to_room(room_id, RoomVersions.V1, local_user_1) - self._add_user_to_room(room_id, RoomVersions.V1, local_user_2) - - results = self.get_success( - self.handler.search_users(searching_user, local_user_1, 20) - )["results"] - received_user_id_ordering = [result["user_id"] for result in results] - self.assertSequenceEqual(received_user_id_ordering[:1], [local_user_1]) - - i += 2 - class TestUserDirSearchDisabled(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/handlers/test_worker_lock.py b/tests/handlers/test_worker_lock.py
index 6e9a15c8ee..8fec3d6bb7 100644 --- a/tests/handlers/test_worker_lock.py +++ b/tests/handlers/test_worker_lock.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2023 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -27,7 +26,6 @@ from synapse.util import Clock from tests import unittest from tests.replication._base import BaseMultiWorkerStreamTestCase -from tests.utils import test_timeout class WorkerLockTestCase(unittest.HomeserverTestCase): @@ -51,28 +49,6 @@ class WorkerLockTestCase(unittest.HomeserverTestCase): self.get_success(d2) self.get_success(lock2.__aexit__(None, None, None)) - def test_lock_contention(self) -> None: - """Test lock contention when a lot of locks wait on a single worker""" - - # It takes around 0.5s on a 5+ years old laptop - with test_timeout(5): - nb_locks = 500 - d = self._take_locks(nb_locks) - self.assertEqual(self.get_success(d), nb_locks) - - async def _take_locks(self, nb_locks: int) -> int: - locks = [ - self.hs.get_worker_locks_handler().acquire_lock("test_lock", "") - for _ in range(nb_locks) - ] - - nb_locks_taken = 0 - for lock in locks: - async with lock: - nb_locks_taken += 1 - - return nb_locks_taken - class WorkerLockWorkersTestCase(BaseMultiWorkerStreamTestCase): def prepare( diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index 8e8621e348..4233b1f2cd 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2014-2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/http/server/__init__.py b/tests/http/server/__init__.py
index dab387a504..3d833a2e44 100644 --- a/tests/http/server/__init__.py +++ b/tests/http/server/__init__.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
index 731b0c4e59..c1b5f2e3b9 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/http/test_client.py b/tests/http/test_client.py
index 721917f957..ff507c7a0f 100644 --- a/tests/http/test_client.py +++ b/tests/http/test_client.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -37,155 +36,18 @@ from synapse.http.client import ( BlocklistingAgentWrapper, BlocklistingReactorWrapper, BodyExceededMaxSize, - MultipartResponse, _DiscardBodyWithMaxSizeProtocol, - _MultipartParserProtocol, read_body_with_max_size, - read_multipart_response, ) from tests.server import FakeTransport, get_clock from tests.unittest import TestCase -class ReadMultipartResponseTests(TestCase): - data1 = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_" - data2 = b"to_stream\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n" - - redirect_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nLocation: https://cdn.example.org/ab/c1/2345.txt\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n" - - def _build_multipart_response( - self, response_length: Union[int, str], max_length: int - ) -> Tuple[ - BytesIO, - "Deferred[MultipartResponse]", - _MultipartParserProtocol, - ]: - """Start reading the body, returns the response, result and proto""" - response = Mock(length=response_length) - result = BytesIO() - boundary = "6067d4698f8d40a0a794ea7d7379d53a" - deferred = read_multipart_response(response, result, boundary, max_length) - - # Fish the protocol out of the response. 
- protocol = response.deliverBody.call_args[0][0] - protocol.transport = Mock() - - return result, deferred, protocol - - def _assert_error( - self, - deferred: "Deferred[MultipartResponse]", - protocol: _MultipartParserProtocol, - ) -> None: - """Ensure that the expected error is received.""" - assert isinstance(deferred.result, Failure) - self.assertIsInstance(deferred.result.value, BodyExceededMaxSize) - assert protocol.transport is not None - # type-ignore: presumably abortConnection has been replaced with a Mock. - protocol.transport.abortConnection.assert_called_once() # type: ignore[attr-defined] - - def _cleanup_error(self, deferred: "Deferred[MultipartResponse]") -> None: - """Ensure that the error in the Deferred is handled gracefully.""" - called = [False] - - def errback(f: Failure) -> None: - called[0] = True - - deferred.addErrback(errback) - self.assertTrue(called[0]) - - def test_parse_file(self) -> None: - """ - Check that a multipart response containing a file is properly parsed - into the json/file parts, and the json and file are properly captured - """ - result, deferred, protocol = self._build_multipart_response(249, 250) - - # Start sending data. - protocol.dataReceived(self.data1) - protocol.dataReceived(self.data2) - # Close the connection. 
- protocol.connectionLost(Failure(ResponseDone())) - - multipart_response: MultipartResponse = deferred.result # type: ignore[assignment] - - self.assertEqual(multipart_response.json, b"{}") - self.assertEqual(result.getvalue(), b"file_to_stream") - self.assertEqual(multipart_response.length, len(b"file_to_stream")) - self.assertEqual(multipart_response.content_type, b"text/plain") - self.assertEqual( - multipart_response.disposition, b"inline; filename=test_upload" - ) - - def test_parse_redirect(self) -> None: - """ - check that a multipart response containing a redirect is properly parsed and redirect url is - returned - """ - result, deferred, protocol = self._build_multipart_response(249, 250) - - # Start sending data. - protocol.dataReceived(self.redirect_data) - # Close the connection. - protocol.connectionLost(Failure(ResponseDone())) - - multipart_response: MultipartResponse = deferred.result # type: ignore[assignment] - - self.assertEqual(multipart_response.json, b"{}") - self.assertEqual(result.getvalue(), b"") - self.assertEqual( - multipart_response.url, b"https://cdn.example.org/ab/c1/2345.txt" - ) - - def test_too_large(self) -> None: - """A response which is too large raises an exception.""" - result, deferred, protocol = self._build_multipart_response(UNKNOWN_LENGTH, 180) - - # Start sending data. - protocol.dataReceived(self.data1) - - self.assertEqual(result.getvalue(), b"file_") - self._assert_error(deferred, protocol) - self._cleanup_error(deferred) - - def test_additional_data(self) -> None: - """A connection can receive data after being closed.""" - result, deferred, protocol = self._build_multipart_response(UNKNOWN_LENGTH, 180) - - # Start sending data. - protocol.dataReceived(self.data1) - self._assert_error(deferred, protocol) - - # More data might have come in. 
- protocol.dataReceived(self.data2) - - self.assertEqual(result.getvalue(), b"file_") - self._assert_error(deferred, protocol) - self._cleanup_error(deferred) - - def test_content_length(self) -> None: - """The body shouldn't be read (at all) if the Content-Length header is too large.""" - result, deferred, protocol = self._build_multipart_response(250, 1) - - # Deferred shouldn't be called yet. - self.assertFalse(deferred.called) - - # Start sending data. - protocol.dataReceived(self.data1) - self._assert_error(deferred, protocol) - self._cleanup_error(deferred) - - # The data is never consumed. - self.assertEqual(result.getvalue(), b"") - - class ReadBodyWithMaxSizeTests(TestCase): - def _build_response(self, length: Union[int, str] = UNKNOWN_LENGTH) -> Tuple[ - BytesIO, - "Deferred[int]", - _DiscardBodyWithMaxSizeProtocol, - ]: + def _build_response( + self, length: Union[int, str] = UNKNOWN_LENGTH + ) -> Tuple[BytesIO, "Deferred[int]", _DiscardBodyWithMaxSizeProtocol]: """Start reading the body, returns the response, result and proto""" response = Mock(length=length) result = BytesIO() diff --git a/tests/http/test_proxy.py b/tests/http/test_proxy.py
index 5895270494..4b2355536a 100644 --- a/tests/http/test_proxy.py +++ b/tests/http/test_proxy.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2023 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index f71e4c2b8f..8fc66aee26 100644 --- a/tests/http/test_proxyagent.py +++ b/tests/http/test_proxyagent.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py
index 18af2735fe..d16dc419f3 100644 --- a/tests/http/test_servlet.py +++ b/tests/http/test_servlet.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/http/test_simple_client.py b/tests/http/test_simple_client.py
index b7806fa947..bfcbe0fb10 100644 --- a/tests/http/test_simple_client.py +++ b/tests/http/test_simple_client.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/http/test_site.py b/tests/http/test_site.py
index bfa26a329c..3c54569e9a 100644 --- a/tests/http/test_site.py +++ b/tests/http/test_site.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/logging/__init__.py b/tests/logging/__init__.py
index f8f2260c2e..13b2754f7b 100644 --- a/tests/logging/__init__.py +++ b/tests/logging/__init__.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py
index c7ef2bd7a4..38144a9a45 100644 --- a/tests/logging/test_opentracing.py +++ b/tests/logging/test_opentracing.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py
index f5412ac6e2..80e228b829 100644 --- a/tests/logging/test_remote_handler.py +++ b/tests/logging/test_remote_handler.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index ff85e067b7..4460b050d9 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/media/__init__.py b/tests/media/__init__.py
index b401d78ef1..3d833a2e44 100644 --- a/tests/media/__init__.py +++ b/tests/media/__init__.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2023 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/media/test_filepath.py b/tests/media/test_filepath.py
index cd21b369b4..9d8989ce34 100644 --- a/tests/media/test_filepath.py +++ b/tests/media/test_filepath.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/media/test_html_preview.py b/tests/media/test_html_preview.py
index d3f1e8833a..0ab33e87b4 100644 --- a/tests/media/test_html_preview.py +++ b/tests/media/test_html_preview.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2014-2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/media/test_media_retention.py b/tests/media/test_media_retention.py
index 417d17ebd2..7f613f351b 100644 --- a/tests/media/test_media_retention.py +++ b/tests/media/test_media_retention.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py
index e55001fb40..a42383bcb6 100644 --- a/tests/media/test_media_storage.py +++ b/tests/media/test_media_storage.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -24,7 +23,7 @@ import tempfile from binascii import unhexlify from io import BytesIO from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import Mock from urllib import parse import attr @@ -36,12 +35,9 @@ from twisted.internet import defer from twisted.internet.defer import Deferred from twisted.python.failure import Failure from twisted.test.proto_helpers import MemoryReactor -from twisted.web.http_headers import Headers -from twisted.web.iweb import UNKNOWN_LENGTH, IResponse from twisted.web.resource import Resource from synapse.api.errors import Codes, HttpResponseException -from synapse.api.ratelimiting import Ratelimiter from synapse.events import EventBase from synapse.http.types import QueryParams from synapse.logging.context import make_deferred_yieldable @@ -49,11 +45,11 @@ from synapse.media._base import FileInfo, ThumbnailInfo from synapse.media.filepath import MediaFilePaths from synapse.media.media_storage import MediaStorage, ReadableFileWrapper from synapse.media.storage_provider import FileStorageProviderBackend -from synapse.media.thumbnailer import ThumbnailProvider from synapse.module_api import ModuleApi from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_checkers from synapse.rest import admin -from synapse.rest.client import login, media +from synapse.rest.client import login +from synapse.rest.media.thumbnail_resource import ThumbnailResource from synapse.server import HomeServer from synapse.types import JsonDict, RoomAlias from synapse.util import Clock @@ -61,7 +57,6 @@ from synapse.util import Clock from tests import unittest from tests.server import 
FakeChannel from tests.test_utils import SMALL_PNG -from tests.unittest import override_config from tests.utils import default_config @@ -128,7 +123,7 @@ class MediaStorageTests(unittest.HomeserverTestCase): @attr.s(auto_attribs=True, slots=True, frozen=True) -class TestImage: +class _TestImage: """An image for testing thumbnailing with the expected results Attributes: @@ -157,54 +152,68 @@ class TestImage: is_inline: bool = True -small_png = TestImage( - SMALL_PNG, - b"image/png", - b".png", - unhexlify( - b"89504e470d0a1a0a0000000d4948445200000020000000200806" - b"000000737a7af40000001a49444154789cedc101010000008220" - b"ffaf6e484001000000ef0610200001194334ee0000000049454e" - b"44ae426082" - ), - unhexlify( - b"89504e470d0a1a0a0000000d4948445200000001000000010806" - b"0000001f15c4890000000d49444154789c636060606000000005" - b"0001a5f645400000000049454e44ae426082" - ), -) - -small_png_with_transparency = TestImage( - unhexlify( - b"89504e470d0a1a0a0000000d49484452000000010000000101000" - b"00000376ef9240000000274524e5300010194fdae0000000a4944" - b"4154789c636800000082008177cd72b60000000049454e44ae426" - b"082" - ), - b"image/png", - b".png", - # Note that we don't check the output since it varies across - # different versions of Pillow. 
-) - -small_lossless_webp = TestImage( - unhexlify( - b"524946461a000000574542505650384c0d0000002f0000001007" b"1011118888fe0700" - ), - b"image/webp", - b".webp", -) - -empty_file = TestImage( - b"", - b"image/gif", - b".gif", - expected_found=False, - unable_to_thumbnail=True, -) - -SVG = TestImage( - b"""<?xml version="1.0"?> +@parameterized_class( + ("test_image",), + [ + # small png + ( + _TestImage( + SMALL_PNG, + b"image/png", + b".png", + unhexlify( + b"89504e470d0a1a0a0000000d4948445200000020000000200806" + b"000000737a7af40000001a49444154789cedc101010000008220" + b"ffaf6e484001000000ef0610200001194334ee0000000049454e" + b"44ae426082" + ), + unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000d49444154789c636060606000000005" + b"0001a5f645400000000049454e44ae426082" + ), + ), + ), + # small png with transparency. + ( + _TestImage( + unhexlify( + b"89504e470d0a1a0a0000000d49484452000000010000000101000" + b"00000376ef9240000000274524e5300010194fdae0000000a4944" + b"4154789c636800000082008177cd72b60000000049454e44ae426" + b"082" + ), + b"image/png", + b".png", + # Note that we don't check the output since it varies across + # different versions of Pillow. + ), + ), + # small lossless webp + ( + _TestImage( + unhexlify( + b"524946461a000000574542505650384c0d0000002f0000001007" + b"1011118888fe0700" + ), + b"image/webp", + b".webp", + ), + ), + # an empty file + ( + _TestImage( + b"", + b"image/gif", + b".gif", + expected_found=False, + unable_to_thumbnail=True, + ), + ), + # An SVG. 
+ ( + _TestImage( + b"""<?xml version="1.0"?> <!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"> @@ -213,26 +222,17 @@ SVG = TestImage( <circle cx="100" cy="100" r="50" stroke="black" stroke-width="5" fill="red" /> </svg>""", - b"image/svg", - b".svg", - expected_found=False, - unable_to_thumbnail=True, - is_inline=False, + b"image/svg", + b".svg", + expected_found=False, + unable_to_thumbnail=True, + is_inline=False, + ), + ), + ], ) -test_images = [ - small_png, - small_png_with_transparency, - small_lossless_webp, - empty_file, - SVG, -] -input_values = [(x,) for x in test_images] - - -@parameterized_class(("test_image",), input_values) class MediaRepoTests(unittest.HomeserverTestCase): - servlets = [media.register_servlets] - test_image: ClassVar[TestImage] + test_image: ClassVar[_TestImage] hijack_auth = True user_id = "@test:user" @@ -250,11 +250,9 @@ class MediaRepoTests(unittest.HomeserverTestCase): destination: str, path: str, output_stream: BinaryIO, - download_ratelimiter: Ratelimiter, - ip_address: Any, - max_size: int, args: Optional[QueryParams] = None, retry_on_dns_fail: bool = True, + max_size: Optional[int] = None, ignore_backoff: bool = False, follow_redirects: bool = False, ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]": @@ -503,7 +501,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): params = "?width=32&height=32&method=scale" channel = self.make_request( "GET", - f"/_matrix/media/r0/thumbnail/{self.media_id}{params}", + f"/_matrix/media/v3/thumbnail/{self.media_id}{params}", shorthand=False, await_result=False, ) @@ -531,7 +529,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - f"/_matrix/media/r0/thumbnail/{self.media_id}{params}", + f"/_matrix/media/v3/thumbnail/{self.media_id}{params}", shorthand=False, await_result=False, ) @@ -572,6 +570,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): await_result=False, ) self.pump() + headers = 
{ b"Content-Length": [b"%d" % (len(self.test_image.data))], b"Content-Type": [self.test_image.content_type], @@ -580,6 +579,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): (self.test_image.data, (len(self.test_image.data), headers)) ) self.pump() + if expected_found: self.assertEqual(channel.code, 200) @@ -624,12 +624,12 @@ class MediaRepoTests(unittest.HomeserverTestCase): content_type = self.test_image.content_type.decode() media_repo = self.hs.get_media_repository() - thumbnail_provider = ThumbnailProvider( + thumbnail_resouce = ThumbnailResource( self.hs, media_repo, media_repo.media_storage ) self.assertIsNotNone( - thumbnail_provider._select_thumbnail( + thumbnail_resouce._select_thumbnail( desired_width=desired_size, desired_height=desired_size, desired_method=method, @@ -878,249 +878,3 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase): tok=self.tok, expect_code=400, ) - - -class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() - - self.storage_path = self.mktemp() - self.media_store_path = self.mktemp() - os.mkdir(self.storage_path) - os.mkdir(self.media_store_path) - config["media_store_path"] = self.media_store_path - - provider_config = { - "module": "synapse.media.storage_provider.FileStorageProviderBackend", - "store_local": True, - "store_synchronous": False, - "store_remote": True, - "config": {"directory": self.storage_path}, - } - - config["media_storage_providers"] = [provider_config] - - return self.setup_test_homeserver(config=config) - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.repo = hs.get_media_repository() - self.client = hs.get_federation_http_client() - self.store = hs.get_datastores().main - - def create_resource_dict(self) -> Dict[str, Resource]: - # We need to manually set the resource tree to include media, the - # default only does `/_matrix/client` 
APIs. - return {"/_matrix/media": self.hs.get_media_repository_resource()} - - # mock actually reading file body - def read_body_with_max_size_30MiB(*args: Any, **kwargs: Any) -> Deferred: - d: Deferred = defer.Deferred() - d.callback(31457280) - return d - - def read_body_with_max_size_50MiB(*args: Any, **kwargs: Any) -> Deferred: - d: Deferred = defer.Deferred() - d.callback(52428800) - return d - - @patch( - "synapse.http.matrixfederationclient.read_body_with_max_size", - read_body_with_max_size_30MiB, - ) - def test_download_ratelimit_default(self) -> None: - """ - Test remote media download ratelimiting against default configuration - 500MB bucket - and 87kb/second drain rate - """ - - # mock out actually sending the request, returns a 30MiB response - async def _send_request(*args: Any, **kwargs: Any) -> IResponse: - resp = MagicMock(spec=IResponse) - resp.code = 200 - resp.length = 31457280 - resp.headers = Headers({"Content-Type": ["application/octet-stream"]}) - resp.phrase = b"OK" - return resp - - self.client._send_request = _send_request # type: ignore - - # first request should go through - channel = self.make_request( - "GET", - "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz", - shorthand=False, - ) - assert channel.code == 200 - - # next 15 should go through - for i in range(15): - channel2 = self.make_request( - "GET", - f"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy{i}", - shorthand=False, - ) - assert channel2.code == 200 - - # 17th will hit ratelimit - channel3 = self.make_request( - "GET", - "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyx", - shorthand=False, - ) - assert channel3.code == 429 - - # however, a request from a different IP will go through - channel4 = self.make_request( - "GET", - "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz", - shorthand=False, - client_ip="187.233.230.159", - ) - assert channel4.code == 200 - - # at 87Kib/s it should take about 2 minutes 
for enough to drain from bucket that another - # 30MiB download is authorized - The last download was blocked at 503,316,480. - # The next download will be authorized when bucket hits 492,830,720 - # (524,288,000 total capacity - 31,457,280 download size) so 503,316,480 - 492,830,720 ~= 10,485,760 - # needs to drain before another download will be authorized, that will take ~= - # 2 minutes (10,485,760/89,088/60) - self.reactor.pump([2.0 * 60.0]) - - # enough has drained and next request goes through - channel5 = self.make_request( - "GET", - "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyb", - shorthand=False, - ) - assert channel5.code == 200 - - @override_config( - { - "remote_media_download_per_second": "50M", - "remote_media_download_burst_count": "50M", - } - ) - @patch( - "synapse.http.matrixfederationclient.read_body_with_max_size", - read_body_with_max_size_50MiB, - ) - def test_download_rate_limit_config(self) -> None: - """ - Test that download rate limit config options are correctly picked up and applied - """ - - async def _send_request(*args: Any, **kwargs: Any) -> IResponse: - resp = MagicMock(spec=IResponse) - resp.code = 200 - resp.length = 52428800 - resp.headers = Headers({"Content-Type": ["application/octet-stream"]}) - resp.phrase = b"OK" - return resp - - self.client._send_request = _send_request # type: ignore - - # first request should go through - channel = self.make_request( - "GET", - "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz", - shorthand=False, - ) - assert channel.code == 200 - - # immediate second request should fail - channel = self.make_request( - "GET", - "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy1", - shorthand=False, - ) - assert channel.code == 429 - - # advance half a second - self.reactor.pump([0.5]) - - # request still fails - channel = self.make_request( - "GET", - "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy2", - shorthand=False, - ) - assert 
channel.code == 429 - - # advance another half second - self.reactor.pump([0.5]) - - # enough has drained from bucket and request is successful - channel = self.make_request( - "GET", - "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy3", - shorthand=False, - ) - assert channel.code == 200 - - @override_config({"remote_media_download_burst_count": "87M"}) - @patch( - "synapse.http.matrixfederationclient.read_body_with_max_size", - read_body_with_max_size_30MiB, - ) - def test_download_ratelimit_unknown_length(self) -> None: - """ - Test that if no content-length is provided, ratelimit will still be applied after - download once length is known - """ - - # mock out actually sending the request - async def _send_request(*args: Any, **kwargs: Any) -> IResponse: - resp = MagicMock(spec=IResponse) - resp.code = 200 - resp.length = UNKNOWN_LENGTH - resp.headers = Headers({"Content-Type": ["application/octet-stream"]}) - resp.phrase = b"OK" - return resp - - self.client._send_request = _send_request # type: ignore - - # 3 requests should go through (note 3rd one would technically violate ratelimit but - # is applied *after* download - the next one will be ratelimited) - for i in range(3): - channel = self.make_request( - "GET", - f"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy{i}", - shorthand=False, - ) - assert channel.code == 200 - - # 4th will hit ratelimit - channel2 = self.make_request( - "GET", - "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyx", - shorthand=False, - ) - assert channel2.code == 429 - - @override_config({"max_upload_size": "29M"}) - @patch( - "synapse.http.matrixfederationclient.read_body_with_max_size", - read_body_with_max_size_30MiB, - ) - def test_max_download_respected(self) -> None: - """ - Test that the max download size is enforced - note that max download size is determined - by the max_upload_size - """ - - # mock out actually sending the request - async def _send_request(*args: Any, 
**kwargs: Any) -> IResponse: - resp = MagicMock(spec=IResponse) - resp.code = 200 - resp.length = 31457280 - resp.headers = Headers({"Content-Type": ["application/octet-stream"]}) - resp.phrase = b"OK" - return resp - - self.client._send_request = _send_request # type: ignore - - channel = self.make_request( - "GET", "/_matrix/media/v3/download/remote.org/abcd", shorthand=False - ) - assert channel.code == 502 - assert channel.json_body["errcode"] == "M_TOO_LARGE" diff --git a/tests/media/test_oembed.py b/tests/media/test_oembed.py
index 29d4580697..abcb33a3b4 100644 --- a/tests/media/test_oembed.py +++ b/tests/media/test_oembed.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/media/test_url_previewer.py b/tests/media/test_url_previewer.py
index 0ae414d408..8d3aa60657 100644 --- a/tests/media/test_url_previewer.py +++ b/tests/media/test_url_previewer.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2023 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py
index 80f24814e8..c989098685 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019 Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/module_api/test_account_data_manager.py b/tests/module_api/test_account_data_manager.py
index fd87eaffd0..18cd53673b 100644 --- a/tests/module_api/test_account_data_manager.py +++ b/tests/module_api/test_account_data_manager.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index b6ba472d7d..ce142e919f 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -688,7 +687,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase): channel = self.make_request( "GET", - "/notifications", + "/notifications?from=", access_token=tok, ) self.assertEqual(channel.code, 200, channel.result) diff --git a/tests/module_api/test_event_unsigned_addition.py b/tests/module_api/test_event_unsigned_addition.py
index c429eff4d6..39ab52174c 100644 --- a/tests/module_api/test_event_unsigned_addition.py +++ b/tests/module_api/test_event_unsigned_addition.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2023 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py
index fc73f3dc2a..e0d6d027bf 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index e0aab1c046..c927a73fa6 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py
@@ -205,24 +205,8 @@ class EmailPusherTests(HomeserverTestCase): # Multipart: plain text, base 64 encoded; html, base 64 encoded multipart_msg = email.message_from_bytes(msg) - - # Extract the text (non-HTML) portion of the multipart Message, - # as a Message. - txt_message = multipart_msg.get_payload(i=0) - assert isinstance(txt_message, email.message.Message) - - # Extract the actual bytes from the Message object, and decode them to a `str`. - txt_bytes = txt_message.get_payload(decode=True) - assert isinstance(txt_bytes, bytes) - txt = txt_bytes.decode() - - # Do the same for the HTML portion of the multipart Message. - html_message = multipart_msg.get_payload(i=1) - assert isinstance(html_message, email.message.Message) - html_bytes = html_message.get_payload(decode=True) - assert isinstance(html_bytes, bytes) - html = html_bytes.decode() - + txt = multipart_msg.get_payload()[0].get_payload(decode=True).decode() + html = multipart_msg.get_payload()[1].get_payload(decode=True).decode() self.assertIn("/_synapse/client/unsubscribe", txt) self.assertIn("/_synapse/client/unsubscribe", html) @@ -363,17 +347,12 @@ class EmailPusherTests(HomeserverTestCase): # That email should contain the room's avatar msg: bytes = args[5] # Multipart: plain text, base 64 encoded; html, base 64 encoded - - # Extract the html Message object from the Multipart Message. - # We need the asserts to convince mypy that this is OK. - html_message = email.message_from_bytes(msg).get_payload(i=1) - assert isinstance(html_message, email.message.Message) - - # Extract the `bytes` from the html Message object, and decode to a `str`. 
- html = html_message.get_payload(decode=True) - assert isinstance(html, bytes) - html = html.decode() - + html = ( + email.message_from_bytes(msg) + .get_payload()[1] + .get_payload(decode=True) + .decode() + ) self.assertIn("_matrix/media/v1/thumbnail/DUMMY_MEDIA_ID", html) def test_empty_room(self) -> None: diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index bcca472617..dce00d8b7f 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py
@@ -26,8 +26,7 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.logging.context import make_deferred_yieldable from synapse.push import PusherConfig, PusherConfigException -from synapse.rest.admin.experimental_features import ExperimentalFeature -from synapse.rest.client import login, push_rule, pusher, receipts, room, versions +from synapse.rest.client import login, push_rule, pusher, receipts, room from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock @@ -43,7 +42,6 @@ class HTTPPusherTests(HomeserverTestCase): receipts.register_servlets, push_rule.register_servlets, pusher.register_servlets, - versions.register_servlets, ] user_id = True hijack_auth = False @@ -971,84 +969,6 @@ class HTTPPusherTests(HomeserverTestCase): lookup_result.device_id, ) - def test_device_id_feature_flag(self) -> None: - """Tests that a pusher created with a given device ID shows that device ID in - GET /pushers requests when feature is enabled for the user - """ - user_id = self.register_user("user", "pass") - access_token = self.login("user", "pass") - - # We create the pusher with an HTTP request rather than with - # _make_user_with_pusher so that we can test the device ID is correctly set when - # creating a pusher via an API call. - self.make_request( - method="POST", - path="/pushers/set", - content={ - "kind": "http", - "app_id": "m.http", - "app_display_name": "HTTP Push Notifications", - "device_display_name": "pushy push", - "pushkey": "a@example.com", - "lang": "en", - "data": {"url": "http://example.com/_matrix/push/v1/notify"}, - }, - access_token=access_token, - ) - - # Look up the user info for the access token so we can compare the device ID. 
- store = self.hs.get_datastores().main - lookup_result = self.get_success(store.get_user_by_access_token(access_token)) - assert lookup_result is not None - - # Check field is not there before we enable the feature flag - channel = self.make_request("GET", "/pushers", access_token=access_token) - self.assertEqual(channel.code, 200) - self.assertEqual(len(channel.json_body["pushers"]), 1) - self.assertNotIn( - "org.matrix.msc3881.device_id", channel.json_body["pushers"][0] - ) - - self.get_success( - store.set_features_for_user(user_id, {ExperimentalFeature.MSC3881: True}) - ) - - # Get the user's devices and check it has the correct device ID. - channel = self.make_request("GET", "/pushers", access_token=access_token) - self.assertEqual(channel.code, 200) - self.assertEqual(len(channel.json_body["pushers"]), 1) - self.assertEqual( - channel.json_body["pushers"][0]["org.matrix.msc3881.device_id"], - lookup_result.device_id, - ) - - def test_msc3881_client_versions_flag(self) -> None: - """Tests that MSC3881 only appears in /versions if user has it enabled.""" - - user_id = self.register_user("user", "pass") - access_token = self.login("user", "pass") - - # Check feature is disabled in /versions - channel = self.make_request( - "GET", "/_matrix/client/versions", access_token=access_token - ) - self.assertEqual(channel.code, 200) - self.assertFalse(channel.json_body["unstable_features"]["org.matrix.msc3881"]) - - # Enable feature for user - self.get_success( - self.hs.get_datastores().main.set_features_for_user( - user_id, {ExperimentalFeature.MSC3881: True} - ) - ) - - # Check feature is now enabled in /versions for user - channel = self.make_request( - "GET", "/_matrix/client/versions", access_token=access_token - ) - self.assertEqual(channel.code, 200) - self.assertTrue(channel.json_body["unstable_features"]["org.matrix.msc3881"]) - @override_config({"push": {"jitter_delay": "10s"}}) def test_jitter(self) -> None: """Tests that enabling jitter actually delays 
sending push.""" diff --git a/tests/push/test_presentable_names.py b/tests/push/test_presentable_names.py
index bd42fc0580..1f41c23e73 100644 --- a/tests/push/test_presentable_names.py +++ b/tests/push/test_presentable_names.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 420fbea998..b129332bb7 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/replication/__init__.py b/tests/replication/__init__.py
index 587ee42067..3d833a2e44 100644 --- a/tests/replication/__init__.py +++ b/tests/replication/__init__.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 8437da1cdd..d2220f8195 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py
@@ -495,9 +495,9 @@ class FakeRedisPubSubServer: """A fake Redis server for pub/sub.""" def __init__(self) -> None: - self._subscribers_by_channel: Dict[bytes, Set["FakeRedisPubSubProtocol"]] = ( - defaultdict(set) - ) + self._subscribers_by_channel: Dict[ + bytes, Set["FakeRedisPubSubProtocol"] + ] = defaultdict(set) def add_subscriber(self, conn: "FakeRedisPubSubProtocol", channel: bytes) -> None: """A connection has called SUBSCRIBE""" diff --git a/tests/replication/http/__init__.py b/tests/replication/http/__init__.py
index dab387a504..3d833a2e44 100644 --- a/tests/replication/http/__init__.py +++ b/tests/replication/http/__init__.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py
index 2eaad3707a..f8a07b34e8 100644 --- a/tests/replication/http/test__base.py +++ b/tests/replication/http/test__base.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/replication/storage/__init__.py b/tests/replication/storage/__init__.py
index 587ee42067..3d833a2e44 100644 --- a/tests/replication/storage/__init__.py +++ b/tests/replication/storage/__init__.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/replication/storage/_base.py b/tests/replication/storage/_base.py
index 27dff0034f..357ff79ed3 100644 --- a/tests/replication/storage/_base.py +++ b/tests/replication/storage/_base.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/replication/storage/test_events.py b/tests/replication/storage/test_events.py
index 1afe523d02..69350963f2 100644 --- a/tests/replication/storage/test_events.py +++ b/tests/replication/storage/test_events.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -30,16 +29,19 @@ from synapse.api.constants import ReceiptTypes from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext +from synapse.handlers.room import RoomEventSource from synapse.server import HomeServer from synapse.storage.databases.main.event_push_actions import ( NotifCounts, RoomNotifCounts, ) from synapse.storage.databases.main.events_worker import EventsWorkerStore -from synapse.storage.roommember import RoomsForUser +from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser from synapse.types import PersistedEventPosition from synapse.util import Clock +from tests.server import FakeTransport + from ._base import BaseWorkerStoreTestCase USER_ID = "@feeling:test" @@ -138,7 +140,6 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase): self.persist(type="m.room.create", key="", creator=USER_ID) self.check("get_invited_rooms_for_local_user", [USER_ID_2], []) event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite") - assert event.internal_metadata.instance_name is not None assert event.internal_metadata.stream_ordering is not None self.replicate() @@ -152,10 +153,7 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase): USER_ID, "invite", event.event_id, - PersistedEventPosition( - event.internal_metadata.instance_name, - event.internal_metadata.stream_ordering, - ), + event.internal_metadata.stream_ordering, RoomVersions.V1.identifier, ) ], @@ -218,6 +216,122 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase): ), ) + def test_get_rooms_for_user_with_stream_ordering(self) -> None: + """Check that the cache on 
get_rooms_for_user_with_stream_ordering is invalidated + by rows in the events stream + """ + self.persist(type="m.room.create", key="", creator=USER_ID) + self.persist(type="m.room.member", key=USER_ID, membership="join") + self.replicate() + self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set()) + + j2 = self.persist( + type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" + ) + assert j2.internal_metadata.stream_ordering is not None + self.replicate() + + expected_pos = PersistedEventPosition( + "master", j2.internal_metadata.stream_ordering + ) + self.check( + "get_rooms_for_user_with_stream_ordering", + (USER_ID_2,), + {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)}, + ) + + def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist( + self, + ) -> None: + """Check that current_state invalidation happens correctly with multiple events + in the persistence batch. + + This test attempts to reproduce a race condition between the event persistence + loop and a worker-based Sync handler. + + The problem occurred when the master persisted several events in one batch. It + only updates the current_state at the end of each batch, so the obvious thing + to do is then to issue a current_state_delta stream update corresponding to the + last stream_id in the batch. + + However, that raises the possibility that a worker will see the replication + notification for a join event before the current_state caches are invalidated. + + The test involves: + * creating a join and a message event for a user, and persisting them in the + same batch + + * controlling the replication stream so that updates are sent gradually + + * between each bunch of replication updates, check that we see a consistent + snapshot of the state. 
+ """ + self.persist(type="m.room.create", key="", creator=USER_ID) + self.persist(type="m.room.member", key=USER_ID, membership="join") + self.replicate() + self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set()) + + # limit the replication rate + repl_transport = self._server_transport + assert isinstance(repl_transport, FakeTransport) + repl_transport.autoflush = False + + # build the join and message events and persist them in the same batch. + logger.info("----- build test events ------") + j2, j2ctx = self.build_event( + type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" + ) + msg, msgctx = self.build_event() + self.get_success(self.persistance.persist_events([(j2, j2ctx), (msg, msgctx)])) + self.replicate() + assert j2.internal_metadata.stream_ordering is not None + + event_source = RoomEventSource(self.hs) + event_source.store = self.worker_store + current_token = event_source.get_current_key() + + # gradually stream out the replication + while repl_transport.buffer: + logger.info("------ flush ------") + repl_transport.flush(30) + self.pump(0) + + prev_token = current_token + current_token = event_source.get_current_key() + + # attempt to replicate the behaviour of the sync handler. + # + # First, we get a list of the rooms we are joined to + joined_rooms = self.get_success( + self.worker_store.get_rooms_for_user_with_stream_ordering(USER_ID_2) + ) + + # Then, we get a list of the events since the last sync + membership_changes = self.get_success( + self.worker_store.get_membership_changes_for_user( + USER_ID_2, prev_token, current_token + ) + ) + + logger.info( + "%s->%s: joined_rooms=%r membership_changes=%r", + prev_token, + current_token, + joined_rooms, + membership_changes, + ) + + # the membership change is only any use to us if the room is in the + # joined_rooms list. 
+ if membership_changes: + expected_pos = PersistedEventPosition( + "master", j2.internal_metadata.stream_ordering + ) + self.assertEqual( + joined_rooms, + {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)}, + ) + event_id = 0 def persist(self, backfill: bool = False, **kwargs: Any) -> EventBase: diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py
index 6dea29ae15..9d934a8dfe 100644 --- a/tests/replication/tcp/streams/test_account_data.py +++ b/tests/replication/tcp/streams/test_account_data.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/replication/tcp/streams/test_partial_state.py b/tests/replication/tcp/streams/test_partial_state.py
index abd4e0b8cd..e3d4c01aae 100644 --- a/tests/replication/tcp/streams/test_partial_state.py +++ b/tests/replication/tcp/streams/test_partial_state.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/replication/tcp/streams/test_to_device.py b/tests/replication/tcp/streams/test_to_device.py
index cb07e93d6b..e9e9590396 100644 --- a/tests/replication/tcp/streams/test_to_device.py +++ b/tests/replication/tcp/streams/test_to_device.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2023 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index b1c2f5b03b..198045dd86 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/replication/tcp/test_commands.py b/tests/replication/tcp/test_commands.py
index 87a2365ae3..10efc00679 100644 --- a/tests/replication/tcp/test_commands.py +++ b/tests/replication/tcp/test_commands.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py
index a8eb7fc523..ecdf8e6679 100644 --- a/tests/replication/tcp/test_handler.py +++ b/tests/replication/tcp/test_handler.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py
index 7820de8acc..517e30d160 100644 --- a/tests/replication/test_auth.py +++ b/tests/replication/test_auth.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index 298c662555..fb8e8ec3e5 100644 --- a/tests/replication/test_client_reader_shard.py +++ b/tests/replication/test_client_reader_shard.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 14c9483f2b..2219ade77e 100644 --- a/tests/replication/test_federation_ack.py +++ b/tests/replication/test_federation_ack.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 4429d0f4e2..6ebd54d8f4 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/replication/test_module_cache_invalidation.py b/tests/replication/test_module_cache_invalidation.py
index 1e7183edaa..bbf3ad371d 100644 --- a/tests/replication/test_module_cache_invalidation.py +++ b/tests/replication/test_module_cache_invalidation.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 6fc4600c41..70c10d126d 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -28,7 +27,7 @@ from twisted.web.http import HTTPChannel from twisted.web.server import Request from synapse.rest import admin -from synapse.rest.client import login, media +from synapse.rest.client import login from synapse.server import HomeServer from synapse.util import Clock @@ -255,238 +254,6 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): return sum(len(files) for _, _, files in os.walk(path)) -class AuthenticatedMediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): - """Checks running multiple media repos work correctly using autheticated media paths""" - - servlets = [ - admin.register_servlets_for_client_rest_resource, - login.register_servlets, - media.register_servlets, - ] - - file_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_to_stream\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n" - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.user_id = self.register_user("user", "pass") - self.access_token = self.login("user", "pass") - - self.reactor.lookups["example.com"] = "1.2.3.4" - - def default_config(self) -> dict: - conf = super().default_config() - conf["federation_custom_ca_list"] = [get_test_ca_cert_file()] - return conf - - def make_worker_hs( - self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any - ) -> HomeServer: - worker_hs = super().make_worker_hs(worker_app, extra_config, **kwargs) - # Force the media paths onto the replication resource. 
- worker_hs.get_media_repository_resource().register_servlets( - self._hs_to_site[worker_hs].resource, worker_hs - ) - return worker_hs - - def _get_media_req( - self, hs: HomeServer, target: str, media_id: str - ) -> Tuple[FakeChannel, Request]: - """Request some remote media from the given HS by calling the download - API. - - This then triggers an outbound request from the HS to the target. - - Returns: - The channel for the *client* request and the *outbound* request for - the media which the caller should respond to. - """ - channel = make_request( - self.reactor, - self._hs_to_site[hs], - "GET", - f"/_matrix/client/v1/media/download/{target}/{media_id}", - shorthand=False, - access_token=self.access_token, - await_result=False, - ) - self.pump() - - clients = self.reactor.tcpClients - self.assertGreaterEqual(len(clients), 1) - (host, port, client_factory, _timeout, _bindAddress) = clients.pop() - - # build the test server - server_factory = Factory.forProtocol(HTTPChannel) - # Request.finish expects the factory to have a 'log' method. - server_factory.log = _log_request - - server_tls_protocol = wrap_server_factory_for_tls( - server_factory, self.reactor, sanlist=[b"DNS:example.com"] - ).buildProtocol(None) - - # now, tell the client protocol factory to build the client protocol (it will be a - # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an - # HTTP11ClientProtocol) and wire the output of said protocol up to the server via - # a FakeTransport. - # - # Normally this would be done by the TCP socket code in Twisted, but we are - # stubbing that out here. 
- client_protocol = client_factory.buildProtocol(None) - client_protocol.makeConnection( - FakeTransport(server_tls_protocol, self.reactor, client_protocol) - ) - - # tell the server tls protocol to send its stuff back to the client, too - server_tls_protocol.makeConnection( - FakeTransport(client_protocol, self.reactor, server_tls_protocol) - ) - - # fish the test server back out of the server-side TLS protocol. - http_server: HTTPChannel = server_tls_protocol.wrappedProtocol - - # give the reactor a pump to get the TLS juices flowing. - self.reactor.pump((0.1,)) - - self.assertEqual(len(http_server.requests), 1) - request = http_server.requests[0] - - self.assertEqual(request.method, b"GET") - self.assertEqual( - request.path, - f"/_matrix/federation/v1/media/download/{media_id}".encode(), - ) - self.assertEqual( - request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")] - ) - - return channel, request - - def test_basic(self) -> None: - """Test basic fetching of remote media from a single worker.""" - hs1 = self.make_worker_hs("synapse.app.generic_worker") - - channel, request = self._get_media_req(hs1, "example.com:443", "ABC123") - - request.setResponseCode(200) - request.responseHeaders.setRawHeaders( - b"Content-Type", - ["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"], - ) - request.write(self.file_data) - request.finish() - - self.pump(0.1) - - self.assertEqual(channel.code, 200) - self.assertEqual(channel.result["body"], b"file_to_stream") - - def test_download_simple_file_race(self) -> None: - """Test that fetching remote media from two different processes at the - same time works. - """ - hs1 = self.make_worker_hs("synapse.app.generic_worker") - hs2 = self.make_worker_hs("synapse.app.generic_worker") - - start_count = self._count_remote_media() - - # Make two requests without responding to the outbound media requests. 
- channel1, request1 = self._get_media_req(hs1, "example.com:443", "ABC123") - channel2, request2 = self._get_media_req(hs2, "example.com:443", "ABC123") - - # Respond to the first outbound media request and check that the client - # request is successful - request1.setResponseCode(200) - request1.responseHeaders.setRawHeaders( - b"Content-Type", - ["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"], - ) - request1.write(self.file_data) - request1.finish() - - self.pump(0.1) - - self.assertEqual(channel1.code, 200, channel1.result["body"]) - self.assertEqual(channel1.result["body"], b"file_to_stream") - - # Now respond to the second with the same content. - request2.setResponseCode(200) - request2.responseHeaders.setRawHeaders( - b"Content-Type", - ["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"], - ) - request2.write(self.file_data) - request2.finish() - - self.pump(0.1) - - self.assertEqual(channel2.code, 200, channel2.result["body"]) - self.assertEqual(channel2.result["body"], b"file_to_stream") - - # We expect only one new file to have been persisted. - self.assertEqual(start_count + 1, self._count_remote_media()) - - def test_download_image_race(self) -> None: - """Test that fetching remote *images* from two different processes at - the same time works. - - This checks that races generating thumbnails are handled correctly. 
- """ - hs1 = self.make_worker_hs("synapse.app.generic_worker") - hs2 = self.make_worker_hs("synapse.app.generic_worker") - - start_count = self._count_remote_thumbnails() - - channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1") - channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1") - - request1.setResponseCode(200) - request1.responseHeaders.setRawHeaders( - b"Content-Type", - ["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"], - ) - img_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: image/png\r\nContent-Disposition: inline; filename=test_img\r\n\r\n" - request1.write(img_data) - request1.write(SMALL_PNG) - request1.write(b"\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n") - request1.finish() - - self.pump(0.1) - - self.assertEqual(channel1.code, 200, channel1.result["body"]) - self.assertEqual(channel1.result["body"], SMALL_PNG) - - request2.setResponseCode(200) - request2.responseHeaders.setRawHeaders( - b"Content-Type", - ["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"], - ) - request2.write(img_data) - request2.write(SMALL_PNG) - request2.write(b"\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n") - request2.finish() - - self.pump(0.1) - - self.assertEqual(channel2.code, 200, channel2.result["body"]) - self.assertEqual(channel2.result["body"], SMALL_PNG) - - # We expect only three new thumbnails to have been persisted. 
- self.assertEqual(start_count + 3, self._count_remote_thumbnails()) - - def _count_remote_media(self) -> int: - """Count the number of files in our remote media directory.""" - path = os.path.join( - self.hs.get_media_repository().primary_base_path, "remote_content" - ) - return sum(len(files) for _, _, files in os.walk(path)) - - def _count_remote_thumbnails(self) -> int: - """Count the number of files in our remote thumbnails directory.""" - path = os.path.join( - self.hs.get_media_repository().primary_base_path, "remote_thumbnail" - ) - return sum(len(files) for _, _, files in os.walk(path)) - - def _log_request(request: Request) -> None: """Implements Factory.log, which is expected by Request.finish""" logger.info("Completed request %s", request) diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 1b0bdc262a..7002521893 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index ce6ad75901..04e58cf5df 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/replication/test_sharded_receipts.py b/tests/replication/test_sharded_receipts.py
index e400267819..81b319c858 100644 --- a/tests/replication/test_sharded_receipts.py +++ b/tests/replication/test_sharded_receipts.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/rest/__init__.py b/tests/rest/__init__.py
index 6a72062b0c..3d833a2e44 100644 --- a/tests/rest/__init__.py +++ b/tests/rest/__init__.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2015, 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 6351326fff..defccd7d12 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -384,7 +383,7 @@ class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase): "PUT", url, content={ - "features": {"msc3881": True}, + "features": {"msc3026": True, "msc3881": True}, }, access_token=self.admin_user_tok, ) @@ -401,6 +400,10 @@ class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual( True, + channel.json_body["features"]["msc3026"], + ) + self.assertEqual( + True, channel.json_body["features"]["msc3881"], ) @@ -409,7 +412,7 @@ class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", url, - content={"features": {"msc3881": False}}, + content={"features": {"msc3026": False}}, access_token=self.admin_user_tok, ) self.assertEqual(channel.code, 200) @@ -425,15 +428,23 @@ class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual( False, + channel.json_body["features"]["msc3026"], + ) + self.assertEqual( + True, channel.json_body["features"]["msc3881"], ) + self.assertEqual( + False, + channel.json_body["features"]["msc3967"], + ) # test nothing blows up if you try to disable a feature that isn't already enabled url = f"{self.url}/{self.other_user}" channel = self.make_request( "PUT", url, - content={"features": {"msc3881": False}}, + content={"features": {"msc3026": False}}, access_token=self.admin_user_tok, ) self.assertEqual(channel.code, 200) diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py
index f33aada64b..caea9d4415 100644 --- a/tests/rest/admin/test_background_updates.py +++ b/tests/rest/admin/test_background_updates.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index a88c77bd19..c8f6fa105a 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 Dirk Klimpel # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index feb410a11d..3e695e9700 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 Dirk Klimpel # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -24,7 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes -from synapse.rest.client import login, reporting, room +from synapse.rest.client import login, report_event, room from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock @@ -37,7 +36,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): synapse.rest.admin.register_servlets, login.register_servlets, room.register_servlets, - reporting.register_servlets, + report_event.register_servlets, ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -453,7 +452,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): synapse.rest.admin.register_servlets, login.register_servlets, room.register_servlets, - reporting.register_servlets, + report_event.register_servlets, ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py
index c2015774a1..1cdb1105eb 100644 --- a/tests/rest/admin/test_federation.py +++ b/tests/rest/admin/test_federation.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -778,81 +777,20 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): self.assertEqual(number_rooms, len(channel.json_body["rooms"])) self._check_fields(channel.json_body["rooms"]) - def test_room_filtering(self) -> None: - """Tests that rooms are correctly filtered""" - - # Create two rooms on the homeserver. Each has a different remote homeserver - # participating in it. - other_destination = "other.destination.org" - room_ids_self_dest = self._create_destination_rooms(2, destination=self.dest) - room_ids_other_dest = self._create_destination_rooms( - 1, destination=other_destination - ) - - # Ask for the rooms that `self.dest` is participating in. - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok) - self.assertEqual(200, channel.code, msg=channel.json_body) - - # Verify that we received only the rooms that `self.dest` is participating in. - # This assertion method name is a bit misleading. It does check that both lists - # contain the same items, and the same counts. - self.assertCountEqual( - [r["room_id"] for r in channel.json_body["rooms"]], room_ids_self_dest - ) - self.assertEqual(channel.json_body["total"], len(room_ids_self_dest)) - - # Ask for the rooms that `other_destination` is participating in. - channel = self.make_request( - "GET", - self.url.replace(self.dest, other_destination), - access_token=self.admin_user_tok, - ) - self.assertEqual(200, channel.code, msg=channel.json_body) - - # Verify that we received only the rooms that `other_destination` is - # participating in. 
- self.assertCountEqual( - [r["room_id"] for r in channel.json_body["rooms"]], room_ids_other_dest - ) - self.assertEqual(channel.json_body["total"], len(room_ids_other_dest)) - - def _create_destination_rooms( - self, - number_rooms: int, - destination: Optional[str] = None, - ) -> List[str]: - """ - Create the given number of rooms. The given `destination` homeserver will - be recorded as a participant. + def _create_destination_rooms(self, number_rooms: int) -> None: + """Create a number rooms for destination Args: number_rooms: Number of rooms to be created - destination: The domain of the homeserver that will be considered - as a participant in the rooms. - - Returns: - The IDs of the rooms that have been created. """ - room_ids = [] - - # If no destination was provided, default to `self.dest`. - if destination is None: - destination = self.dest - for _ in range(number_rooms): room_id = self.helper.create_room_as( self.admin_user, tok=self.admin_user_tok ) - room_ids.append(room_id) - self.get_success( - self.store.store_destination_rooms_entries( - (destination,), room_id, 1234 - ) + self.store.store_destination_rooms_entries((self.dest,), room_id, 1234) ) - return room_ids - def _check_fields(self, content: List[JsonDict]) -> None: """Checks that the expected room attributes are present in content diff --git a/tests/rest/admin/test_jwks.py b/tests/rest/admin/test_jwks.py
index 55b822c4d0..3636ea3415 100644 --- a/tests/rest/admin/test_jwks.py +++ b/tests/rest/admin/test_jwks.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2023 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index f378165513..ea5b4e2c12 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py
@@ -1,8 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. -# Copyright 2020 Dirk Klimpel # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -277,8 +275,7 @@ class DeleteMediaByDateSizeTestCase(_AdminMediaTests): self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) self.assertEqual( - "Missing required integer query parameter before_ts", - channel.json_body["error"], + "Missing integer query parameter 'before_ts'", channel.json_body["error"] ) def test_invalid_parameter(self) -> None: @@ -321,7 +318,7 @@ class DeleteMediaByDateSizeTestCase(_AdminMediaTests): self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( - "Query parameter size_gt must be a positive integer.", + "Query parameter size_gt must be a string representing a positive integer.", channel.json_body["error"], ) diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py
index 67d1db8ff8..773cfa8d6a 100644 --- a/tests/rest/admin/test_registration_tokens.py +++ b/tests/rest/admin/test_registration_tokens.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 Callum Brown # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 95ed736451..a511175b99 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 Dirk Klimpel # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -21,7 +20,6 @@ import json import time import urllib.parse -from http import HTTPStatus from typing import List, Optional from unittest.mock import AsyncMock, Mock @@ -1795,83 +1793,6 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(room_id, channel.json_body["rooms"][0].get("room_id")) self.assertEqual("ж", channel.json_body["rooms"][0].get("name")) - def test_filter_public_rooms(self) -> None: - self.helper.create_room_as( - self.admin_user, tok=self.admin_user_tok, is_public=True - ) - self.helper.create_room_as( - self.admin_user, tok=self.admin_user_tok, is_public=True - ) - self.helper.create_room_as( - self.admin_user, tok=self.admin_user_tok, is_public=False - ) - - response = self.make_request( - "GET", - "/_synapse/admin/v1/rooms", - access_token=self.admin_user_tok, - ) - self.assertEqual(200, response.code, msg=response.json_body) - self.assertEqual(3, response.json_body["total_rooms"]) - self.assertEqual(3, len(response.json_body["rooms"])) - - response = self.make_request( - "GET", - "/_synapse/admin/v1/rooms?public_rooms=true", - access_token=self.admin_user_tok, - ) - self.assertEqual(200, response.code, msg=response.json_body) - self.assertEqual(2, response.json_body["total_rooms"]) - self.assertEqual(2, len(response.json_body["rooms"])) - - response = self.make_request( - "GET", - "/_synapse/admin/v1/rooms?public_rooms=false", - access_token=self.admin_user_tok, - ) - self.assertEqual(200, response.code, msg=response.json_body) - self.assertEqual(1, response.json_body["total_rooms"]) - self.assertEqual(1, len(response.json_body["rooms"])) - - def test_filter_empty_rooms(self) -> None: - self.helper.create_room_as( - self.admin_user, tok=self.admin_user_tok, is_public=True - ) - 
self.helper.create_room_as( - self.admin_user, tok=self.admin_user_tok, is_public=True - ) - room_id = self.helper.create_room_as( - self.admin_user, tok=self.admin_user_tok, is_public=False - ) - self.helper.leave(room_id, self.admin_user, tok=self.admin_user_tok) - - response = self.make_request( - "GET", - "/_synapse/admin/v1/rooms", - access_token=self.admin_user_tok, - ) - self.assertEqual(200, response.code, msg=response.json_body) - self.assertEqual(3, response.json_body["total_rooms"]) - self.assertEqual(3, len(response.json_body["rooms"])) - - response = self.make_request( - "GET", - "/_synapse/admin/v1/rooms?empty_rooms=false", - access_token=self.admin_user_tok, - ) - self.assertEqual(200, response.code, msg=response.json_body) - self.assertEqual(2, response.json_body["total_rooms"]) - self.assertEqual(2, len(response.json_body["rooms"])) - - response = self.make_request( - "GET", - "/_synapse/admin/v1/rooms?empty_rooms=true", - access_token=self.admin_user_tok, - ) - self.assertEqual(200, response.code, msg=response.json_body) - self.assertEqual(1, response.json_body["total_rooms"]) - self.assertEqual(1, len(response.json_body["rooms"])) - def test_single_room(self) -> None: """Test that a single room can be requested correctly""" # Create two test rooms @@ -2268,33 +2189,6 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase): chunk = channel.json_body["chunk"] self.assertEqual(len(chunk), 0, [event["content"] for event in chunk]) - def test_room_message_filter_query_validation(self) -> None: - # Test json validation in (filter) query parameter. - # Does not test the validity of the filter, only the json validation. - - # Check Get with valid json filter parameter, expect 200. 
- valid_filter_str = '{"types": ["m.room.message"]}' - channel = self.make_request( - "GET", - f"/_synapse/admin/v1/rooms/{self.room_id}/messages?dir=b&filter={valid_filter_str}", - access_token=self.admin_user_tok, - ) - - self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) - - # Check Get with invalid json filter parameter, expect 400 NOT_JSON. - invalid_filter_str = "}}}{}" - channel = self.make_request( - "GET", - f"/_synapse/admin/v1/rooms/{self.room_id}/messages?dir=b&filter={invalid_filter_str}", - access_token=self.admin_user_tok, - ) - - self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body) - self.assertEqual( - channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body - ) - class JoinAliasRoomTestCase(unittest.HomeserverTestCase): servlets = [ @@ -2627,39 +2521,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): else: self.fail("Event %s from events_after not found" % j) - def test_room_event_context_filter_query_validation(self) -> None: - # Test json validation in (filter) query parameter. - # Does not test the validity of the filter, only the json validation. - - # Create a user with room and event_id. - user_id = self.register_user("test", "test") - user_tok = self.login("test", "test") - room_id = self.helper.create_room_as(user_id, tok=user_tok) - event_id = self.helper.send(room_id, "message 1", tok=user_tok)["event_id"] - - # Check Get with valid json filter parameter, expect 200. - valid_filter_str = '{"types": ["m.room.message"]}' - channel = self.make_request( - "GET", - f"/_synapse/admin/v1/rooms/{room_id}/context/{event_id}?filter={valid_filter_str}", - access_token=self.admin_user_tok, - ) - - self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) - - # Check Get with invalid json filter parameter, expect 400 NOT_JSON. 
- invalid_filter_str = "}}}{}" - channel = self.make_request( - "GET", - f"/_synapse/admin/v1/rooms/{room_id}/context/{event_id}?filter={invalid_filter_str}", - access_token=self.admin_user_tok, - ) - - self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body) - self.assertEqual( - channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body - ) - class MakeRoomAdminTestCase(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index 2a1e42bbc8..ce5e3a5c1f 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 Dirk Klimpel # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index 5f60e19e56..0b30e8c65f 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py
@@ -1,8 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. -# Copyright 2020 Dirk Klimpel # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 16bb4349f5..61cbac2332 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2018-2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -37,7 +36,6 @@ from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError from synapse.api.room_versions import RoomVersions from synapse.media.filepath import MediaFilePaths -from synapse.rest import admin from synapse.rest.client import ( devices, login, @@ -504,7 +502,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", - f"{self.url}?deactivated=true", + self.url + "?deactivated=true", {}, access_token=self.admin_user_tok, ) @@ -983,56 +981,6 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(1, channel.json_body["total"]) self.assertFalse(channel.json_body["users"][0]["admin"]) - def test_filter_deactivated_users(self) -> None: - """ - Tests whether the various values of the query parameter `deactivated` lead to the - expected result set. - """ - users_url_v3 = self.url.replace("v2", "v3") - - # Register an additional non admin user - user_id = self.register_user("user", "pass", admin=False) - - # Deactivate that user, requesting erasure. 
- deactivate_account_handler = self.hs.get_deactivate_account_handler() - self.get_success( - deactivate_account_handler.deactivate_account( - user_id, erase_data=True, requester=create_requester(user_id) - ) - ) - - # Query all users - channel = self.make_request( - "GET", - users_url_v3, - access_token=self.admin_user_tok, - ) - - self.assertEqual(200, channel.code, channel.result) - self.assertEqual(2, channel.json_body["total"]) - - # Query deactivated users - channel = self.make_request( - "GET", - f"{users_url_v3}?deactivated=true", - access_token=self.admin_user_tok, - ) - - self.assertEqual(200, channel.code, channel.result) - self.assertEqual(1, channel.json_body["total"]) - self.assertEqual("@user:test", channel.json_body["users"][0]["name"]) - - # Query non-deactivated users - channel = self.make_request( - "GET", - f"{users_url_v3}?deactivated=false", - access_token=self.admin_user_tok, - ) - - self.assertEqual(200, channel.code, channel.result) - self.assertEqual(1, channel.json_body["total"]) - self.assertEqual("@admin:test", channel.json_body["users"][0]["name"]) - @override_config( { "experimental_features": { @@ -1181,7 +1129,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): # They should appear in the list users API, marked as not erased. 
channel = self.make_request( "GET", - f"{self.url}?deactivated=true", + self.url + "?deactivated=true", access_token=self.admin_user_tok, ) users = {user["name"]: user for user in channel.json_body["users"]} @@ -1245,7 +1193,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): dir: The direction of ordering to give the server """ - url = f"{self.url}?deactivated=true&" + url = self.url + "?deactivated=true&" if order_by is not None: url += "order_by=%s&" % (order_by,) if dir is not None and dir in ("b", "f"): @@ -5006,86 +4954,3 @@ class AllowCrossSigningReplacementTestCase(unittest.HomeserverTestCase): ) assert timestamp is not None self.assertGreater(timestamp, self.clock.time_msec()) - - -class UserSuspensionTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - admin.register_servlets, - ] - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.admin = self.register_user("thomas", "hackme", True) - self.admin_tok = self.login("thomas", "hackme") - - self.bad_user = self.register_user("teresa", "hackme") - self.bad_user_tok = self.login("teresa", "hackme") - - self.store = hs.get_datastores().main - - @override_config({"experimental_features": {"msc3823_account_suspension": True}}) - def test_suspend_user(self) -> None: - # test that suspending user works - channel = self.make_request( - "PUT", - f"/_synapse/admin/v1/suspend/{self.bad_user}", - {"suspend": True}, - access_token=self.admin_tok, - ) - self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body, {f"user_{self.bad_user}_suspended": True}) - - res = self.get_success(self.store.get_user_suspended_status(self.bad_user)) - self.assertEqual(True, res) - - # test that un-suspending user works - channel2 = self.make_request( - "PUT", - f"/_synapse/admin/v1/suspend/{self.bad_user}", - {"suspend": False}, - access_token=self.admin_tok, - ) - self.assertEqual(channel2.code, 200) - 
self.assertEqual(channel2.json_body, {f"user_{self.bad_user}_suspended": False}) - - res2 = self.get_success(self.store.get_user_suspended_status(self.bad_user)) - self.assertEqual(False, res2) - - # test that trying to un-suspend user who isn't suspended doesn't cause problems - channel3 = self.make_request( - "PUT", - f"/_synapse/admin/v1/suspend/{self.bad_user}", - {"suspend": False}, - access_token=self.admin_tok, - ) - self.assertEqual(channel3.code, 200) - self.assertEqual(channel3.json_body, {f"user_{self.bad_user}_suspended": False}) - - res3 = self.get_success(self.store.get_user_suspended_status(self.bad_user)) - self.assertEqual(False, res3) - - # test that trying to suspend user who is already suspended doesn't cause problems - channel4 = self.make_request( - "PUT", - f"/_synapse/admin/v1/suspend/{self.bad_user}", - {"suspend": True}, - access_token=self.admin_tok, - ) - self.assertEqual(channel4.code, 200) - self.assertEqual(channel4.json_body, {f"user_{self.bad_user}_suspended": True}) - - res4 = self.get_success(self.store.get_user_suspended_status(self.bad_user)) - self.assertEqual(True, res4) - - channel5 = self.make_request( - "PUT", - f"/_synapse/admin/v1/suspend/{self.bad_user}", - {"suspend": True}, - access_token=self.admin_tok, - ) - self.assertEqual(channel5.code, 200) - self.assertEqual(channel5.json_body, {f"user_{self.bad_user}_suspended": True}) - - res5 = self.get_success(self.store.get_user_suspended_status(self.bad_user)) - self.assertEqual(True, res5) diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py
index 4dd5de33d3..d302be33a3 100644 --- a/tests/rest/admin/test_username_available.py +++ b/tests/rest/admin/test_username_available.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/rest/client/__init__.py b/tests/rest/client/__init__.py
index 6a72062b0c..3d833a2e44 100644 --- a/tests/rest/client/__init__.py +++ b/tests/rest/client/__init__.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2015, 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/rest/client/sliding_sync/__init__.py b/tests/rest/client/sliding_sync/__init__.py deleted file mode 100644
index c4de9d53e2..0000000000 --- a/tests/rest/client/sliding_sync/__init__.py +++ /dev/null
@@ -1,13 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2024 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# diff --git a/tests/rest/client/sliding_sync/test_connection_tracking.py b/tests/rest/client/sliding_sync/test_connection_tracking.py deleted file mode 100644
index 4d8866b30a..0000000000 --- a/tests/rest/client/sliding_sync/test_connection_tracking.py +++ /dev/null
@@ -1,453 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2024 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -import logging - -from parameterized import parameterized - -from twisted.test.proto_helpers import MemoryReactor - -import synapse.rest.admin -from synapse.api.constants import EventTypes -from synapse.rest.client import login, room, sync -from synapse.server import HomeServer -from synapse.types import SlidingSyncStreamToken -from synapse.types.handlers import SlidingSyncConfig -from synapse.util import Clock - -from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase - -logger = logging.getLogger(__name__) - - -class SlidingSyncConnectionTrackingTestCase(SlidingSyncBase): - """ - Test connection tracking in the Sliding Sync API. - """ - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - sync.register_servlets, - ] - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastores().main - self.storage_controllers = hs.get_storage_controllers() - - def test_rooms_required_state_incremental_sync_LIVE(self) -> None: - """Test that we only get state updates in incremental sync for rooms - we've already seen (LIVE). 
- """ - - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id1, user1_id, tok=user1_tok) - - # Make the Sliding Sync request - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [ - [EventTypes.Create, ""], - [EventTypes.RoomHistoryVisibility, ""], - # This one doesn't exist in the room - [EventTypes.Name, ""], - ], - "timeline_limit": 0, - } - } - } - - response_body, from_token = self.do_sync(sync_body, tok=user1_tok) - - state_map = self.get_success( - self.storage_controllers.state.get_current_state(room_id1) - ) - - self._assertRequiredStateIncludes( - response_body["rooms"][room_id1]["required_state"], - { - state_map[(EventTypes.Create, "")], - state_map[(EventTypes.RoomHistoryVisibility, "")], - }, - exact=True, - ) - - # Send a state event - self.helper.send_state( - room_id1, EventTypes.Name, body={"name": "foo"}, tok=user2_tok - ) - - response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok) - - state_map = self.get_success( - self.storage_controllers.state.get_current_state(room_id1) - ) - - self.assertNotIn("initial", response_body["rooms"][room_id1]) - self._assertRequiredStateIncludes( - response_body["rooms"][room_id1]["required_state"], - { - state_map[(EventTypes.Name, "")], - }, - exact=True, - ) - - @parameterized.expand([(False,), (True,)]) - def test_rooms_timeline_incremental_sync_PREVIOUSLY(self, limited: bool) -> None: - """ - Test getting room data where we have previously sent down the room, but - we missed sending down some timeline events previously and so its status - is considered PREVIOUSLY. - - There are two versions of this test, one where there are more messages - than the timeline limit, and one where there isn't. 
- """ - - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok) - room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok) - - self.helper.send(room_id1, "msg", tok=user1_tok) - - timeline_limit = 5 - conn_id = "conn_id" - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 0]], - "required_state": [], - "timeline_limit": timeline_limit, - } - }, - "conn_id": "conn_id", - } - - # The first room gets sent down the initial sync - response_body, initial_from_token = self.do_sync(sync_body, tok=user1_tok) - self.assertCountEqual( - response_body["rooms"].keys(), {room_id1}, response_body["rooms"] - ) - - # We now send down some events in room1 (depending on the test param). - expected_events = [] # The set of events in the timeline - if limited: - for _ in range(10): - resp = self.helper.send(room_id1, "msg1", tok=user1_tok) - expected_events.append(resp["event_id"]) - else: - resp = self.helper.send(room_id1, "msg1", tok=user1_tok) - expected_events.append(resp["event_id"]) - - # A second messages happens in the other room, so room1 won't get sent down. - self.helper.send(room_id2, "msg", tok=user1_tok) - - # Only the second room gets sent down sync. - response_body, from_token = self.do_sync( - sync_body, since=initial_from_token, tok=user1_tok - ) - - self.assertCountEqual( - response_body["rooms"].keys(), {room_id2}, response_body["rooms"] - ) - - # FIXME: This is a hack to record that the first room wasn't sent down - # sync, as we don't implement that currently. 
- sliding_sync_handler = self.hs.get_sliding_sync_handler() - requester = self.get_success( - self.hs.get_auth().get_user_by_access_token(user1_tok) - ) - sync_config = SlidingSyncConfig( - user=requester.user, - requester=requester, - conn_id=conn_id, - ) - - parsed_initial_from_token = self.get_success( - SlidingSyncStreamToken.from_string(self.store, initial_from_token) - ) - connection_position = self.get_success( - sliding_sync_handler.connection_store.record_rooms( - sync_config, - parsed_initial_from_token, - sent_room_ids=[], - unsent_room_ids=[room_id1], - ) - ) - - # FIXME: Now fix up `from_token` with new connect position above. - parsed_from_token = self.get_success( - SlidingSyncStreamToken.from_string(self.store, from_token) - ) - parsed_from_token = SlidingSyncStreamToken( - stream_token=parsed_from_token.stream_token, - connection_position=connection_position, - ) - from_token = self.get_success(parsed_from_token.to_string(self.store)) - - # We now send another event to room1, so we should sync all the missing events. - resp = self.helper.send(room_id1, "msg2", tok=user1_tok) - expected_events.append(resp["event_id"]) - - # This sync should contain the messages from room1 not yet sent down. 
- response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok) - - self.assertCountEqual( - response_body["rooms"].keys(), {room_id1}, response_body["rooms"] - ) - self.assertNotIn("initial", response_body["rooms"][room_id1]) - - self.assertEqual( - [ev["event_id"] for ev in response_body["rooms"][room_id1]["timeline"]], - expected_events[-timeline_limit:], - ) - self.assertEqual(response_body["rooms"][room_id1]["limited"], limited) - self.assertEqual(response_body["rooms"][room_id1].get("required_state"), None) - - def test_rooms_required_state_incremental_sync_PREVIOUSLY(self) -> None: - """ - Test getting room data where we have previously sent down the room, but - we missed sending down some state previously and so its status is - considered PREVIOUSLY. - """ - - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok) - room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok) - - self.helper.send(room_id1, "msg", tok=user1_tok) - - conn_id = "conn_id" - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 0]], - "required_state": [ - [EventTypes.Create, ""], - [EventTypes.RoomHistoryVisibility, ""], - # This one doesn't exist in the room - [EventTypes.Name, ""], - ], - "timeline_limit": 0, - } - }, - "conn_id": "conn_id", - } - - # The first room gets sent down the initial sync - response_body, initial_from_token = self.do_sync(sync_body, tok=user1_tok) - self.assertCountEqual( - response_body["rooms"].keys(), {room_id1}, response_body["rooms"] - ) - - # We now send down some state in room1 - resp = self.helper.send_state( - room_id1, EventTypes.Name, {"name": "foo"}, tok=user1_tok - ) - name_change_id = resp["event_id"] - - # A second messages happens in the other room, so room1 won't get sent down. - self.helper.send(room_id2, "msg", tok=user1_tok) - - # Only the second room gets sent down sync. 
- response_body, from_token = self.do_sync( - sync_body, since=initial_from_token, tok=user1_tok - ) - - self.assertCountEqual( - response_body["rooms"].keys(), {room_id2}, response_body["rooms"] - ) - - # FIXME: This is a hack to record that the first room wasn't sent down - # sync, as we don't implement that currently. - sliding_sync_handler = self.hs.get_sliding_sync_handler() - requester = self.get_success( - self.hs.get_auth().get_user_by_access_token(user1_tok) - ) - sync_config = SlidingSyncConfig( - user=requester.user, - requester=requester, - conn_id=conn_id, - ) - - parsed_initial_from_token = self.get_success( - SlidingSyncStreamToken.from_string(self.store, initial_from_token) - ) - connection_position = self.get_success( - sliding_sync_handler.connection_store.record_rooms( - sync_config, - parsed_initial_from_token, - sent_room_ids=[], - unsent_room_ids=[room_id1], - ) - ) - - # FIXME: Now fix up `from_token` with new connect position above. - parsed_from_token = self.get_success( - SlidingSyncStreamToken.from_string(self.store, from_token) - ) - parsed_from_token = SlidingSyncStreamToken( - stream_token=parsed_from_token.stream_token, - connection_position=connection_position, - ) - from_token = self.get_success(parsed_from_token.to_string(self.store)) - - # We now send another event to room1, so we should sync all the missing state. - self.helper.send(room_id1, "msg", tok=user1_tok) - - # This sync should contain the state changes from room1. - response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok) - - self.assertCountEqual( - response_body["rooms"].keys(), {room_id1}, response_body["rooms"] - ) - self.assertNotIn("initial", response_body["rooms"][room_id1]) - - # We should only see the name change. 
- self.assertEqual( - [ - ev["event_id"] - for ev in response_body["rooms"][room_id1]["required_state"] - ], - [name_change_id], - ) - - def test_rooms_required_state_incremental_sync_NEVER(self) -> None: - """ - Test getting `required_state` where we have NEVER sent down the room before - """ - - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok) - room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok) - - self.helper.send(room_id1, "msg", tok=user1_tok) - - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 0]], - "required_state": [ - [EventTypes.Create, ""], - [EventTypes.RoomHistoryVisibility, ""], - # This one doesn't exist in the room - [EventTypes.Name, ""], - ], - "timeline_limit": 1, - } - }, - } - - # A message happens in the other room, so room1 won't get sent down. - self.helper.send(room_id2, "msg", tok=user1_tok) - - # Only the second room gets sent down sync. - response_body, from_token = self.do_sync(sync_body, tok=user1_tok) - - self.assertCountEqual( - response_body["rooms"].keys(), {room_id2}, response_body["rooms"] - ) - - # We now send another event to room1, so we should send down the full - # room. - self.helper.send(room_id1, "msg2", tok=user1_tok) - - # This sync should contain the messages from room1 not yet sent down. 
- response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok) - - self.assertCountEqual( - response_body["rooms"].keys(), {room_id1}, response_body["rooms"] - ) - - self.assertEqual(response_body["rooms"][room_id1]["initial"], True) - - state_map = self.get_success( - self.storage_controllers.state.get_current_state(room_id1) - ) - - self._assertRequiredStateIncludes( - response_body["rooms"][room_id1]["required_state"], - { - state_map[(EventTypes.Create, "")], - state_map[(EventTypes.RoomHistoryVisibility, "")], - }, - exact=True, - ) - - def test_rooms_timeline_incremental_sync_NEVER(self) -> None: - """ - Test getting timeline room data where we have NEVER sent down the room - before - """ - - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok) - room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok) - - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 0]], - "required_state": [], - "timeline_limit": 5, - } - }, - } - - expected_events = [] - for _ in range(4): - resp = self.helper.send(room_id1, "msg", tok=user1_tok) - expected_events.append(resp["event_id"]) - - # A message happens in the other room, so room1 won't get sent down. - self.helper.send(room_id2, "msg", tok=user1_tok) - - # Only the second room gets sent down sync. - response_body, from_token = self.do_sync(sync_body, tok=user1_tok) - - self.assertCountEqual( - response_body["rooms"].keys(), {room_id2}, response_body["rooms"] - ) - - # We now send another event to room1 so it comes down sync - resp = self.helper.send(room_id1, "msg2", tok=user1_tok) - expected_events.append(resp["event_id"]) - - # This sync should contain the messages from room1 not yet sent down. 
- response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok) - - self.assertCountEqual( - response_body["rooms"].keys(), {room_id1}, response_body["rooms"] - ) - - self.assertEqual( - [ev["event_id"] for ev in response_body["rooms"][room_id1]["timeline"]], - expected_events, - ) - self.assertEqual(response_body["rooms"][room_id1]["limited"], True) - self.assertEqual(response_body["rooms"][room_id1]["initial"], True) diff --git a/tests/rest/client/sliding_sync/test_extension_account_data.py b/tests/rest/client/sliding_sync/test_extension_account_data.py deleted file mode 100644
index 3482a5f887..0000000000 --- a/tests/rest/client/sliding_sync/test_extension_account_data.py +++ /dev/null
@@ -1,495 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2024 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -import logging - -from twisted.test.proto_helpers import MemoryReactor - -import synapse.rest.admin -from synapse.api.constants import AccountDataTypes -from synapse.rest.client import login, room, sendtodevice, sync -from synapse.server import HomeServer -from synapse.types import StreamKeyType -from synapse.util import Clock - -from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase -from tests.server import TimedOutException - -logger = logging.getLogger(__name__) - - -class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase): - """Tests for the account_data sliding sync extension""" - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - sync.register_servlets, - sendtodevice.register_servlets, - ] - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastores().main - self.account_data_handler = hs.get_account_data_handler() - - def test_no_data_initial_sync(self) -> None: - """ - Test that enabling the account_data extension works during an intitial sync, - even if there is no-data. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - # Make an initial Sliding Sync request with the account_data extension enabled - sync_body = { - "lists": {}, - "extensions": { - "account_data": { - "enabled": True, - } - }, - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - self.assertIncludes( - { - global_event["type"] - for global_event in response_body["extensions"]["account_data"].get( - "global" - ) - }, - # Even though we don't have any global account data set, Synapse saves some - # default push rules for us. - {AccountDataTypes.PUSH_RULES}, - exact=True, - ) - self.assertIncludes( - response_body["extensions"]["account_data"].get("rooms").keys(), - set(), - exact=True, - ) - - def test_no_data_incremental_sync(self) -> None: - """ - Test that enabling account_data extension works during an incremental sync, even - if there is no-data. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - sync_body = { - "lists": {}, - "extensions": { - "account_data": { - "enabled": True, - } - }, - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - # Make an incremental Sliding Sync request with the account_data extension enabled - response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok) - - # There has been no account data changes since the `from_token` so we shouldn't - # see any account data here. - self.assertIncludes( - { - global_event["type"] - for global_event in response_body["extensions"]["account_data"].get( - "global" - ) - }, - set(), - exact=True, - ) - self.assertIncludes( - response_body["extensions"]["account_data"].get("rooms").keys(), - set(), - exact=True, - ) - - def test_global_account_data_initial_sync(self) -> None: - """ - On initial sync, we should return all global account data on initial sync. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - # Update the global account data - self.get_success( - self.account_data_handler.add_account_data_for_user( - user_id=user1_id, - account_data_type="org.matrix.foobarbaz", - content={"foo": "bar"}, - ) - ) - - # Make an initial Sliding Sync request with the account_data extension enabled - sync_body = { - "lists": {}, - "extensions": { - "account_data": { - "enabled": True, - } - }, - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # It should show us all of the global account data - self.assertIncludes( - { - global_event["type"] - for global_event in response_body["extensions"]["account_data"].get( - "global" - ) - }, - {AccountDataTypes.PUSH_RULES, "org.matrix.foobarbaz"}, - exact=True, - ) - self.assertIncludes( - response_body["extensions"]["account_data"].get("rooms").keys(), - set(), - exact=True, - ) - - def test_global_account_data_incremental_sync(self) -> None: - """ - On incremental sync, we should only account data that has changed since the - `from_token`. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - # Add some global account data - self.get_success( - self.account_data_handler.add_account_data_for_user( - user_id=user1_id, - account_data_type="org.matrix.foobarbaz", - content={"foo": "bar"}, - ) - ) - - sync_body = { - "lists": {}, - "extensions": { - "account_data": { - "enabled": True, - } - }, - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - # Add some other global account data - self.get_success( - self.account_data_handler.add_account_data_for_user( - user_id=user1_id, - account_data_type="org.matrix.doodardaz", - content={"doo": "dar"}, - ) - ) - - # Make an incremental Sliding Sync request with the account_data extension enabled - response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok) - - self.assertIncludes( - { - global_event["type"] - for global_event in response_body["extensions"]["account_data"].get( - "global" - ) - }, - # We should only see the new global account data that happened after the `from_token` - {"org.matrix.doodardaz"}, - exact=True, - ) - self.assertIncludes( - response_body["extensions"]["account_data"].get("rooms").keys(), - set(), - exact=True, - ) - - def test_room_account_data_initial_sync(self) -> None: - """ - On initial sync, we return all account data for a given room but only for - rooms that we request and are being returned in the Sliding Sync response. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - # Create a room and add some room account data - room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok) - self.get_success( - self.account_data_handler.add_account_data_to_room( - user_id=user1_id, - room_id=room_id1, - account_data_type="org.matrix.roorarraz", - content={"roo": "rar"}, - ) - ) - - # Create another room with some room account data - room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok) - self.get_success( - self.account_data_handler.add_account_data_to_room( - user_id=user1_id, - room_id=room_id2, - account_data_type="org.matrix.roorarraz", - content={"roo": "rar"}, - ) - ) - - # Make an initial Sliding Sync request with the account_data extension enabled - sync_body = { - "lists": {}, - "room_subscriptions": { - room_id1: { - "required_state": [], - "timeline_limit": 0, - } - }, - "extensions": { - "account_data": { - "enabled": True, - "rooms": [room_id1, room_id2], - } - }, - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - self.assertIsNotNone(response_body["extensions"]["account_data"].get("global")) - # Even though we requested room2, we only expect room1 to show up because that's - # the only room in the Sliding Sync response (room2 is not one of our room - # subscriptions or in a sliding window list). - self.assertIncludes( - response_body["extensions"]["account_data"].get("rooms").keys(), - {room_id1}, - exact=True, - ) - self.assertIncludes( - { - event["type"] - for event in response_body["extensions"]["account_data"] - .get("rooms") - .get(room_id1) - }, - {"org.matrix.roorarraz"}, - exact=True, - ) - - def test_room_account_data_incremental_sync(self) -> None: - """ - On incremental sync, we return all account data for a given room but only for - rooms that we request and are being returned in the Sliding Sync response. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - # Create a room and add some room account data - room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok) - self.get_success( - self.account_data_handler.add_account_data_to_room( - user_id=user1_id, - room_id=room_id1, - account_data_type="org.matrix.roorarraz", - content={"roo": "rar"}, - ) - ) - - # Create another room with some room account data - room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok) - self.get_success( - self.account_data_handler.add_account_data_to_room( - user_id=user1_id, - room_id=room_id2, - account_data_type="org.matrix.roorarraz", - content={"roo": "rar"}, - ) - ) - - sync_body = { - "lists": {}, - "room_subscriptions": { - room_id1: { - "required_state": [], - "timeline_limit": 0, - } - }, - "extensions": { - "account_data": { - "enabled": True, - "rooms": [room_id1, room_id2], - } - }, - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - # Add some other room account data - self.get_success( - self.account_data_handler.add_account_data_to_room( - user_id=user1_id, - room_id=room_id1, - account_data_type="org.matrix.roorarraz2", - content={"roo": "rar"}, - ) - ) - self.get_success( - self.account_data_handler.add_account_data_to_room( - user_id=user1_id, - room_id=room_id2, - account_data_type="org.matrix.roorarraz2", - content={"roo": "rar"}, - ) - ) - - # Make an incremental Sliding Sync request with the account_data extension enabled - response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok) - - self.assertIsNotNone(response_body["extensions"]["account_data"].get("global")) - # Even though we requested room2, we only expect room1 to show up because that's - # the only room in the Sliding Sync response (room2 is not one of our room - # subscriptions or in a sliding window list). 
- self.assertIncludes( - response_body["extensions"]["account_data"].get("rooms").keys(), - {room_id1}, - exact=True, - ) - # We should only see the new room account data that happened after the `from_token` - self.assertIncludes( - { - event["type"] - for event in response_body["extensions"]["account_data"] - .get("rooms") - .get(room_id1) - }, - {"org.matrix.roorarraz2"}, - exact=True, - ) - - def test_wait_for_new_data(self) -> None: - """ - Test to make sure that the Sliding Sync request waits for new data to arrive. - - (Only applies to incremental syncs with a `timeout` specified) - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id, user1_id, tok=user1_tok) - - sync_body = { - "lists": {}, - "extensions": { - "account_data": { - "enabled": True, - } - }, - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - # Make an incremental Sliding Sync request with the account_data extension enabled - channel = self.make_request( - "POST", - self.sync_endpoint + f"?timeout=10000&pos={from_token}", - content=sync_body, - access_token=user1_tok, - await_result=False, - ) - # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)` - with self.assertRaises(TimedOutException): - channel.await_result(timeout_ms=5000) - # Bump the global account data to trigger new results - self.get_success( - self.account_data_handler.add_account_data_for_user( - user1_id, - "org.matrix.foobarbaz", - {"foo": "bar"}, - ) - ) - # Should respond before the 10 second timeout - channel.await_result(timeout_ms=3000) - self.assertEqual(channel.code, 200, channel.json_body) - - # We should see the global account data update - self.assertIncludes( - { - global_event["type"] - for global_event in channel.json_body["extensions"]["account_data"].get( 
- "global" - ) - }, - {"org.matrix.foobarbaz"}, - exact=True, - ) - self.assertIncludes( - channel.json_body["extensions"]["account_data"].get("rooms").keys(), - set(), - exact=True, - ) - - def test_wait_for_new_data_timeout(self) -> None: - """ - Test to make sure that the Sliding Sync request waits for new data to arrive but - no data ever arrives so we timeout. We're also making sure that the default data - from the account_data extension doesn't trigger a false-positive for new data. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - sync_body = { - "lists": {}, - "extensions": { - "account_data": { - "enabled": True, - } - }, - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - # Make the Sliding Sync request - channel = self.make_request( - "POST", - self.sync_endpoint + f"?timeout=10000&pos={from_token}", - content=sync_body, - access_token=user1_tok, - await_result=False, - ) - # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)` - with self.assertRaises(TimedOutException): - channel.await_result(timeout_ms=5000) - # Wake-up `notifier.wait_for_events(...)` that will cause us test - # `SlidingSyncResult.__bool__` for new results. - self._bump_notifier_wait_for_events( - user1_id, - # We choose `StreamKeyType.PRESENCE` because we're testing for account data - # and don't want to contaminate the account data results using - # `StreamKeyType.ACCOUNT_DATA`. - wake_stream_key=StreamKeyType.PRESENCE, - ) - # Block for a little bit more to ensure we don't see any new results. 
- with self.assertRaises(TimedOutException): - channel.await_result(timeout_ms=4000) - # Wait for the sync to complete (wait for the rest of the 10 second timeout, - # 5000 + 4000 + 1200 > 10000) - channel.await_result(timeout_ms=1200) - self.assertEqual(channel.code, 200, channel.json_body) - - self.assertIsNotNone( - channel.json_body["extensions"]["account_data"].get("global") - ) - self.assertIsNotNone( - channel.json_body["extensions"]["account_data"].get("rooms") - ) diff --git a/tests/rest/client/sliding_sync/test_extension_e2ee.py b/tests/rest/client/sliding_sync/test_extension_e2ee.py deleted file mode 100644
index 320f8c788f..0000000000 --- a/tests/rest/client/sliding_sync/test_extension_e2ee.py +++ /dev/null
@@ -1,441 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2024 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -import logging - -from twisted.test.proto_helpers import MemoryReactor - -import synapse.rest.admin -from synapse.rest.client import devices, login, room, sync -from synapse.server import HomeServer -from synapse.types import JsonDict, StreamKeyType -from synapse.util import Clock - -from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase -from tests.server import TimedOutException - -logger = logging.getLogger(__name__) - - -class SlidingSyncE2eeExtensionTestCase(SlidingSyncBase): - """Tests for the e2ee sliding sync extension""" - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - sync.register_servlets, - devices.register_servlets, - ] - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastores().main - self.e2e_keys_handler = hs.get_e2e_keys_handler() - - def test_no_data_initial_sync(self) -> None: - """ - Test that enabling e2ee extension works during an intitial sync, even if there - is no-data - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - # Make an initial Sliding Sync request with the e2ee extension enabled - sync_body = { - "lists": {}, - "extensions": { - "e2ee": { - "enabled": True, - } - }, - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # Device list updates are only present for incremental syncs - 
self.assertIsNone(response_body["extensions"]["e2ee"].get("device_lists")) - - # Both of these should be present even when empty - self.assertEqual( - response_body["extensions"]["e2ee"]["device_one_time_keys_count"], - { - # This is always present because of - # https://github.com/element-hq/element-android/issues/3725 and - # https://github.com/matrix-org/synapse/issues/10456 - "signed_curve25519": 0 - }, - ) - self.assertEqual( - response_body["extensions"]["e2ee"]["device_unused_fallback_key_types"], - [], - ) - - def test_no_data_incremental_sync(self) -> None: - """ - Test that enabling e2ee extension works during an incremental sync, even if - there is no-data - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - sync_body = { - "lists": {}, - "extensions": { - "e2ee": { - "enabled": True, - } - }, - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - # Make an incremental Sliding Sync request with the e2ee extension enabled - response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok) - - # Device list shows up for incremental syncs - self.assertEqual( - response_body["extensions"]["e2ee"].get("device_lists", {}).get("changed"), - [], - ) - self.assertEqual( - response_body["extensions"]["e2ee"].get("device_lists", {}).get("left"), - [], - ) - - # Both of these should be present even when empty - self.assertEqual( - response_body["extensions"]["e2ee"]["device_one_time_keys_count"], - { - # Note that "signed_curve25519" is always returned in key count responses - # regardless of whether we uploaded any keys for it. This is necessary until - # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed. 
- # - # Also related: - # https://github.com/element-hq/element-android/issues/3725 and - # https://github.com/matrix-org/synapse/issues/10456 - "signed_curve25519": 0 - }, - ) - self.assertEqual( - response_body["extensions"]["e2ee"]["device_unused_fallback_key_types"], - [], - ) - - def test_wait_for_new_data(self) -> None: - """ - Test to make sure that the Sliding Sync request waits for new data to arrive. - - (Only applies to incremental syncs with a `timeout` specified) - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - test_device_id = "TESTDEVICE" - user3_id = self.register_user("user3", "pass") - user3_tok = self.login(user3_id, "pass", device_id=test_device_id) - - room_id = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id, user1_id, tok=user1_tok) - self.helper.join(room_id, user3_id, tok=user3_tok) - - sync_body = { - "lists": {}, - "extensions": { - "e2ee": { - "enabled": True, - } - }, - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - # Make the Sliding Sync request - channel = self.make_request( - "POST", - self.sync_endpoint + "?timeout=10000" + f"&pos={from_token}", - content=sync_body, - access_token=user1_tok, - await_result=False, - ) - # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)` - with self.assertRaises(TimedOutException): - channel.await_result(timeout_ms=5000) - # Bump the device lists to trigger new results - # Have user3 update their device list - device_update_channel = self.make_request( - "PUT", - f"/devices/{test_device_id}", - { - "display_name": "New Device Name", - }, - access_token=user3_tok, - ) - self.assertEqual( - device_update_channel.code, 200, device_update_channel.json_body - ) - # Should respond before the 10 second timeout - channel.await_result(timeout_ms=3000) - self.assertEqual(channel.code, 200, 
channel.json_body) - - # We should see the device list update - self.assertEqual( - channel.json_body["extensions"]["e2ee"] - .get("device_lists", {}) - .get("changed"), - [user3_id], - ) - self.assertEqual( - channel.json_body["extensions"]["e2ee"].get("device_lists", {}).get("left"), - [], - ) - - def test_wait_for_new_data_timeout(self) -> None: - """ - Test to make sure that the Sliding Sync request waits for new data to arrive but - no data ever arrives so we timeout. We're also making sure that the default data - from the E2EE extension doesn't trigger a false-positive for new data (see - `device_one_time_keys_count.signed_curve25519`). - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - sync_body = { - "lists": {}, - "extensions": { - "e2ee": { - "enabled": True, - } - }, - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - # Make the Sliding Sync request - channel = self.make_request( - "POST", - self.sync_endpoint + f"?timeout=10000&pos={from_token}", - content=sync_body, - access_token=user1_tok, - await_result=False, - ) - # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)` - with self.assertRaises(TimedOutException): - channel.await_result(timeout_ms=5000) - # Wake-up `notifier.wait_for_events(...)` that will cause us test - # `SlidingSyncResult.__bool__` for new results. - self._bump_notifier_wait_for_events( - user1_id, wake_stream_key=StreamKeyType.ACCOUNT_DATA - ) - # Block for a little bit more to ensure we don't see any new results. 
- with self.assertRaises(TimedOutException): - channel.await_result(timeout_ms=4000) - # Wait for the sync to complete (wait for the rest of the 10 second timeout, - # 5000 + 4000 + 1200 > 10000) - channel.await_result(timeout_ms=1200) - self.assertEqual(channel.code, 200, channel.json_body) - - # Device lists are present for incremental syncs but empty because no device changes - self.assertEqual( - channel.json_body["extensions"]["e2ee"] - .get("device_lists", {}) - .get("changed"), - [], - ) - self.assertEqual( - channel.json_body["extensions"]["e2ee"].get("device_lists", {}).get("left"), - [], - ) - - # Both of these should be present even when empty - self.assertEqual( - channel.json_body["extensions"]["e2ee"]["device_one_time_keys_count"], - { - # Note that "signed_curve25519" is always returned in key count responses - # regardless of whether we uploaded any keys for it. This is necessary until - # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed. - # - # Also related: - # https://github.com/element-hq/element-android/issues/3725 and - # https://github.com/matrix-org/synapse/issues/10456 - "signed_curve25519": 0 - }, - ) - self.assertEqual( - channel.json_body["extensions"]["e2ee"]["device_unused_fallback_key_types"], - [], - ) - - def test_device_lists(self) -> None: - """ - Test that device list updates are included in the response - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - test_device_id = "TESTDEVICE" - user3_id = self.register_user("user3", "pass") - user3_tok = self.login(user3_id, "pass", device_id=test_device_id) - - user4_id = self.register_user("user4", "pass") - user4_tok = self.login(user4_id, "pass") - - room_id = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id, user1_id, tok=user1_tok) - self.helper.join(room_id, user3_id, tok=user3_tok) - 
self.helper.join(room_id, user4_id, tok=user4_tok) - - sync_body = { - "lists": {}, - "extensions": { - "e2ee": { - "enabled": True, - } - }, - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - # Have user3 update their device list - channel = self.make_request( - "PUT", - f"/devices/{test_device_id}", - { - "display_name": "New Device Name", - }, - access_token=user3_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # User4 leaves the room - self.helper.leave(room_id, user4_id, tok=user4_tok) - - # Make an incremental Sliding Sync request with the e2ee extension enabled - response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok) - - # Device list updates show up - self.assertEqual( - response_body["extensions"]["e2ee"].get("device_lists", {}).get("changed"), - [user3_id], - ) - self.assertEqual( - response_body["extensions"]["e2ee"].get("device_lists", {}).get("left"), - [user4_id], - ) - - def test_device_one_time_keys_count(self) -> None: - """ - Test that `device_one_time_keys_count` are included in the response - """ - test_device_id = "TESTDEVICE" - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass", device_id=test_device_id) - - # Upload one time keys for the user/device - keys: JsonDict = { - "alg1:k1": "key1", - "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}}, - "alg2:k3": {"key": "key3"}, - } - upload_keys_response = self.get_success( - self.e2e_keys_handler.upload_keys_for_user( - user1_id, test_device_id, {"one_time_keys": keys} - ) - ) - self.assertDictEqual( - upload_keys_response, - { - "one_time_key_counts": { - "alg1": 1, - "alg2": 2, - # Note that "signed_curve25519" is always returned in key count responses - # regardless of whether we uploaded any keys for it. This is necessary until - # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed. 
- # - # Also related: - # https://github.com/element-hq/element-android/issues/3725 and - # https://github.com/matrix-org/synapse/issues/10456 - "signed_curve25519": 0, - } - }, - ) - - # Make a Sliding Sync request with the e2ee extension enabled - sync_body = { - "lists": {}, - "extensions": { - "e2ee": { - "enabled": True, - } - }, - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # Check for those one time key counts - self.assertEqual( - response_body["extensions"]["e2ee"].get("device_one_time_keys_count"), - { - "alg1": 1, - "alg2": 2, - # Note that "signed_curve25519" is always returned in key count responses - # regardless of whether we uploaded any keys for it. This is necessary until - # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed. - # - # Also related: - # https://github.com/element-hq/element-android/issues/3725 and - # https://github.com/matrix-org/synapse/issues/10456 - "signed_curve25519": 0, - }, - ) - - def test_device_unused_fallback_key_types(self) -> None: - """ - Test that `device_unused_fallback_key_types` are included in the response - """ - test_device_id = "TESTDEVICE" - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass", device_id=test_device_id) - - # We shouldn't have any unused fallback keys yet - res = self.get_success( - self.store.get_e2e_unused_fallback_key_types(user1_id, test_device_id) - ) - self.assertEqual(res, []) - - # Upload a fallback key for the user/device - self.get_success( - self.e2e_keys_handler.upload_keys_for_user( - user1_id, - test_device_id, - {"fallback_keys": {"alg1:k1": "fallback_key1"}}, - ) - ) - # We should now have an unused alg1 key - fallback_res = self.get_success( - self.store.get_e2e_unused_fallback_key_types(user1_id, test_device_id) - ) - self.assertEqual(fallback_res, ["alg1"], fallback_res) - - # Make a Sliding Sync request with the e2ee extension enabled - sync_body = { - "lists": {}, - "extensions": { - "e2ee": { - 
"enabled": True, - } - }, - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # Check for the unused fallback key types - self.assertListEqual( - response_body["extensions"]["e2ee"].get("device_unused_fallback_key_types"), - ["alg1"], - ) diff --git a/tests/rest/client/sliding_sync/test_extension_receipts.py b/tests/rest/client/sliding_sync/test_extension_receipts.py deleted file mode 100644
index 65fbac260e..0000000000 --- a/tests/rest/client/sliding_sync/test_extension_receipts.py +++ /dev/null
@@ -1,679 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2024 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -import logging - -from twisted.test.proto_helpers import MemoryReactor - -import synapse.rest.admin -from synapse.api.constants import EduTypes, ReceiptTypes -from synapse.rest.client import login, receipts, room, sync -from synapse.server import HomeServer -from synapse.types import StreamKeyType -from synapse.util import Clock - -from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase -from tests.server import TimedOutException - -logger = logging.getLogger(__name__) - - -class SlidingSyncReceiptsExtensionTestCase(SlidingSyncBase): - """Tests for the receipts sliding sync extension""" - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - sync.register_servlets, - receipts.register_servlets, - ] - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastores().main - - def test_no_data_initial_sync(self) -> None: - """ - Test that enabling the receipts extension works during an intitial sync, - even if there is no-data. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - # Make an initial Sliding Sync request with the receipts extension enabled - sync_body = { - "lists": {}, - "extensions": { - "receipts": { - "enabled": True, - } - }, - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - self.assertIncludes( - response_body["extensions"]["receipts"].get("rooms").keys(), - set(), - exact=True, - ) - - def test_no_data_incremental_sync(self) -> None: - """ - Test that enabling receipts extension works during an incremental sync, even - if there is no-data. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - sync_body = { - "lists": {}, - "extensions": { - "receipts": { - "enabled": True, - } - }, - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - # Make an incremental Sliding Sync request with the receipts extension enabled - response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok) - - self.assertIncludes( - response_body["extensions"]["receipts"].get("rooms").keys(), - set(), - exact=True, - ) - - def test_receipts_initial_sync_with_timeline(self) -> None: - """ - On initial sync, we only return receipts for events in a given room's timeline. - - We also make sure that we only return receipts for rooms that we request and are - already being returned in the Sliding Sync response. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - user3_id = self.register_user("user3", "pass") - user3_tok = self.login(user3_id, "pass") - user4_id = self.register_user("user4", "pass") - user4_tok = self.login(user4_id, "pass") - - # Create a room - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id1, user1_id, tok=user1_tok) - self.helper.join(room_id1, user3_id, tok=user3_tok) - self.helper.join(room_id1, user4_id, tok=user4_tok) - room1_event_response1 = self.helper.send( - room_id1, body="new event1", tok=user2_tok - ) - room1_event_response2 = self.helper.send( - room_id1, body="new event2", tok=user2_tok - ) - # User1 reads the last event - channel = self.make_request( - "POST", - f"/rooms/{room_id1}/receipt/{ReceiptTypes.READ}/{room1_event_response2['event_id']}", - {}, - access_token=user1_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - # User2 reads the last event - channel = self.make_request( - "POST", - f"/rooms/{room_id1}/receipt/{ReceiptTypes.READ}/{room1_event_response2['event_id']}", - {}, - access_token=user2_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - # User3 reads the first event - channel = self.make_request( - "POST", - f"/rooms/{room_id1}/receipt/{ReceiptTypes.READ}/{room1_event_response1['event_id']}", - {}, - access_token=user3_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - # User4 privately reads the last event (make sure this doesn't leak to the other users) - channel = self.make_request( - "POST", - f"/rooms/{room_id1}/receipt/{ReceiptTypes.READ_PRIVATE}/{room1_event_response2['event_id']}", - {}, - access_token=user4_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Create another room - room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok) - 
self.helper.join(room_id2, user1_id, tok=user1_tok) - self.helper.join(room_id2, user3_id, tok=user3_tok) - self.helper.join(room_id2, user4_id, tok=user4_tok) - room2_event_response1 = self.helper.send( - room_id2, body="new event2", tok=user2_tok - ) - # User1 reads the last event - channel = self.make_request( - "POST", - f"/rooms/{room_id2}/receipt/{ReceiptTypes.READ}/{room2_event_response1['event_id']}", - {}, - access_token=user1_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - # User2 reads the last event - channel = self.make_request( - "POST", - f"/rooms/{room_id2}/receipt/{ReceiptTypes.READ}/{room2_event_response1['event_id']}", - {}, - access_token=user2_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - # User4 privately reads the last event (make sure this doesn't leak to the other users) - channel = self.make_request( - "POST", - f"/rooms/{room_id2}/receipt/{ReceiptTypes.READ_PRIVATE}/{room2_event_response1['event_id']}", - {}, - access_token=user4_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Make an initial Sliding Sync request with the receipts extension enabled - sync_body = { - "lists": {}, - "room_subscriptions": { - room_id1: { - "required_state": [], - # On initial sync, we only have receipts for events in the timeline - "timeline_limit": 1, - } - }, - "extensions": { - "receipts": { - "enabled": True, - "rooms": [room_id1, room_id2], - } - }, - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # Only the latest event in the room is in the timelie because the `timeline_limit` is 1 - self.assertIncludes( - { - event["event_id"] - for event in response_body["rooms"][room_id1].get("timeline", []) - }, - {room1_event_response2["event_id"]}, - exact=True, - message=str(response_body["rooms"][room_id1]), - ) - - # Even though we requested room2, we only expect room1 to show up because that's - # the only room in the Sliding Sync response (room2 is not one of our room - # 
subscriptions or in a sliding window list). - self.assertIncludes( - response_body["extensions"]["receipts"].get("rooms").keys(), - {room_id1}, - exact=True, - ) - # Sanity check that it's the correct ephemeral event type - self.assertEqual( - response_body["extensions"]["receipts"]["rooms"][room_id1]["type"], - EduTypes.RECEIPT, - ) - # We can see user1 and user2 read receipts - self.assertIncludes( - response_body["extensions"]["receipts"]["rooms"][room_id1]["content"][ - room1_event_response2["event_id"] - ][ReceiptTypes.READ].keys(), - {user1_id, user2_id}, - exact=True, - ) - # User1 did not have a private read receipt and we shouldn't leak others' - # private read receipts - self.assertIncludes( - response_body["extensions"]["receipts"]["rooms"][room_id1]["content"][ - room1_event_response2["event_id"] - ] - .get(ReceiptTypes.READ_PRIVATE, {}) - .keys(), - set(), - exact=True, - ) - - # We shouldn't see receipts for event2 since it wasn't in the timeline and this is an initial sync - self.assertIsNone( - response_body["extensions"]["receipts"]["rooms"][room_id1]["content"].get( - room1_event_response1["event_id"] - ) - ) - - def test_receipts_incremental_sync(self) -> None: - """ - On incremental sync, we return all receipts in the token range for a given room - but only for rooms that we request and are being returned in the Sliding Sync - response. 
- """ - - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - user3_id = self.register_user("user3", "pass") - user3_tok = self.login(user3_id, "pass") - - # Create room1 - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id1, user1_id, tok=user1_tok) - self.helper.join(room_id1, user3_id, tok=user3_tok) - room1_event_response1 = self.helper.send( - room_id1, body="new event2", tok=user2_tok - ) - # User2 reads the last event (before the `from_token`) - channel = self.make_request( - "POST", - f"/rooms/{room_id1}/receipt/{ReceiptTypes.READ}/{room1_event_response1['event_id']}", - {}, - access_token=user2_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Create room2 - room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id2, user1_id, tok=user1_tok) - room2_event_response1 = self.helper.send( - room_id2, body="new event2", tok=user2_tok - ) - # User1 reads the last event (before the `from_token`) - channel = self.make_request( - "POST", - f"/rooms/{room_id2}/receipt/{ReceiptTypes.READ}/{room2_event_response1['event_id']}", - {}, - access_token=user1_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Create room3 - room_id3 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id3, user1_id, tok=user1_tok) - self.helper.join(room_id3, user3_id, tok=user3_tok) - room3_event_response1 = self.helper.send( - room_id3, body="new event", tok=user2_tok - ) - - # Create room4 - room_id4 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id4, user1_id, tok=user1_tok) - self.helper.join(room_id4, user3_id, tok=user3_tok) - event_response4 = self.helper.send(room_id4, body="new event", tok=user2_tok) - # User1 reads the last event (before the `from_token`) - channel = 
self.make_request( - "POST", - f"/rooms/{room_id4}/receipt/{ReceiptTypes.READ}/{event_response4['event_id']}", - {}, - access_token=user1_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - - sync_body = { - "lists": {}, - "room_subscriptions": { - room_id1: { - "required_state": [], - "timeline_limit": 0, - }, - room_id3: { - "required_state": [], - "timeline_limit": 0, - }, - room_id4: { - "required_state": [], - "timeline_limit": 0, - }, - }, - "extensions": { - "receipts": { - "enabled": True, - "rooms": [room_id1, room_id2, room_id3, room_id4], - } - }, - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - # Add some more read receipts after the `from_token` - # - # User1 reads room1 - channel = self.make_request( - "POST", - f"/rooms/{room_id1}/receipt/{ReceiptTypes.READ}/{room1_event_response1['event_id']}", - {}, - access_token=user1_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - # User1 privately reads room2 - channel = self.make_request( - "POST", - f"/rooms/{room_id2}/receipt/{ReceiptTypes.READ_PRIVATE}/{room2_event_response1['event_id']}", - {}, - access_token=user1_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - # User3 reads room3 - channel = self.make_request( - "POST", - f"/rooms/{room_id3}/receipt/{ReceiptTypes.READ}/{room3_event_response1['event_id']}", - {}, - access_token=user3_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - # No activity for room4 after the `from_token` - - # Make an incremental Sliding Sync request with the receipts extension enabled - response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok) - - # Even though we requested room2, we only expect rooms to show up if they are - # already in the Sliding Sync response. room4 doesn't show up because there is - # no activity after the `from_token`. 
- self.assertIncludes( - response_body["extensions"]["receipts"].get("rooms").keys(), - {room_id1, room_id3}, - exact=True, - ) - - # Check room1: - # - # Sanity check that it's the correct ephemeral event type - self.assertEqual( - response_body["extensions"]["receipts"]["rooms"][room_id1]["type"], - EduTypes.RECEIPT, - ) - # We only see that user1 has read something in room1 since the `from_token` - self.assertIncludes( - response_body["extensions"]["receipts"]["rooms"][room_id1]["content"][ - room1_event_response1["event_id"] - ][ReceiptTypes.READ].keys(), - {user1_id}, - exact=True, - ) - # User1 did not send a private read receipt in this room and we shouldn't leak - # others' private read receipts - self.assertIncludes( - response_body["extensions"]["receipts"]["rooms"][room_id1]["content"][ - room1_event_response1["event_id"] - ] - .get(ReceiptTypes.READ_PRIVATE, {}) - .keys(), - set(), - exact=True, - ) - # No events in the timeline since they were sent before the `from_token` - self.assertNotIn(room_id1, response_body["rooms"]) - - # Check room3: - # - # Sanity check that it's the correct ephemeral event type - self.assertEqual( - response_body["extensions"]["receipts"]["rooms"][room_id3]["type"], - EduTypes.RECEIPT, - ) - # We only see that user3 has read something in room1 since the `from_token` - self.assertIncludes( - response_body["extensions"]["receipts"]["rooms"][room_id3]["content"][ - room3_event_response1["event_id"] - ][ReceiptTypes.READ].keys(), - {user3_id}, - exact=True, - ) - # User1 did not send a private read receipt in this room and we shouldn't leak - # others' private read receipts - self.assertIncludes( - response_body["extensions"]["receipts"]["rooms"][room_id3]["content"][ - room3_event_response1["event_id"] - ] - .get(ReceiptTypes.READ_PRIVATE, {}) - .keys(), - set(), - exact=True, - ) - # No events in the timeline since they were sent before the `from_token` - self.assertNotIn(room_id3, response_body["rooms"]) - - def 
test_receipts_incremental_sync_all_live_receipts(self) -> None: - """ - On incremental sync, we return all receipts in the token range for a given room - even if they are not in the timeline. - """ - - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - # Create room1 - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id1, user1_id, tok=user1_tok) - - sync_body = { - "lists": {}, - "room_subscriptions": { - room_id1: { - "required_state": [], - # The timeline will only include event2 - "timeline_limit": 1, - }, - }, - "extensions": { - "receipts": { - "enabled": True, - "rooms": [room_id1], - } - }, - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - room1_event_response1 = self.helper.send( - room_id1, body="new event1", tok=user2_tok - ) - room1_event_response2 = self.helper.send( - room_id1, body="new event2", tok=user2_tok - ) - - # User1 reads event1 - channel = self.make_request( - "POST", - f"/rooms/{room_id1}/receipt/{ReceiptTypes.READ}/{room1_event_response1['event_id']}", - {}, - access_token=user1_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - # User2 reads event2 - channel = self.make_request( - "POST", - f"/rooms/{room_id1}/receipt/{ReceiptTypes.READ}/{room1_event_response2['event_id']}", - {}, - access_token=user2_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Make an incremental Sliding Sync request with the receipts extension enabled - response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok) - - # We should see room1 because it has receipts in the token range - self.assertIncludes( - response_body["extensions"]["receipts"].get("rooms").keys(), - {room_id1}, - exact=True, - ) - # Sanity check that it's the correct ephemeral event type - self.assertEqual( - 
response_body["extensions"]["receipts"]["rooms"][room_id1]["type"], - EduTypes.RECEIPT, - ) - # We should see all receipts in the token range regardless of whether the events - # are in the timeline - self.assertIncludes( - response_body["extensions"]["receipts"]["rooms"][room_id1]["content"][ - room1_event_response1["event_id"] - ][ReceiptTypes.READ].keys(), - {user1_id}, - exact=True, - ) - self.assertIncludes( - response_body["extensions"]["receipts"]["rooms"][room_id1]["content"][ - room1_event_response2["event_id"] - ][ReceiptTypes.READ].keys(), - {user2_id}, - exact=True, - ) - # Only the latest event in the timeline because the `timeline_limit` is 1 - self.assertIncludes( - { - event["event_id"] - for event in response_body["rooms"][room_id1].get("timeline", []) - }, - {room1_event_response2["event_id"]}, - exact=True, - message=str(response_body["rooms"][room_id1]), - ) - - def test_wait_for_new_data(self) -> None: - """ - Test to make sure that the Sliding Sync request waits for new data to arrive. 
- - (Only applies to incremental syncs with a `timeout` specified) - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id, user1_id, tok=user1_tok) - event_response = self.helper.send(room_id, body="new event", tok=user2_tok) - - sync_body = { - "lists": {}, - "room_subscriptions": { - room_id: { - "required_state": [], - "timeline_limit": 0, - }, - }, - "extensions": { - "receipts": { - "enabled": True, - "rooms": [room_id], - } - }, - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - # Make an incremental Sliding Sync request with the receipts extension enabled - channel = self.make_request( - "POST", - self.sync_endpoint + f"?timeout=10000&pos={from_token}", - content=sync_body, - access_token=user1_tok, - await_result=False, - ) - # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)` - with self.assertRaises(TimedOutException): - channel.await_result(timeout_ms=5000) - # Bump the receipts to trigger new results - receipt_channel = self.make_request( - "POST", - f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{event_response['event_id']}", - {}, - access_token=user2_tok, - ) - self.assertEqual(receipt_channel.code, 200, receipt_channel.json_body) - # Should respond before the 10 second timeout - channel.await_result(timeout_ms=3000) - self.assertEqual(channel.code, 200, channel.json_body) - - # We should see the new receipt - self.assertIncludes( - channel.json_body.get("extensions", {}) - .get("receipts", {}) - .get("rooms", {}) - .keys(), - {room_id}, - exact=True, - message=str(channel.json_body), - ) - self.assertIncludes( - channel.json_body["extensions"]["receipts"]["rooms"][room_id]["content"][ - event_response["event_id"] - ][ReceiptTypes.READ].keys(), - {user2_id}, - exact=True, - ) - # User1 did 
not send a private read receipt in this room and we shouldn't leak - # others' private read receipts - self.assertIncludes( - channel.json_body["extensions"]["receipts"]["rooms"][room_id]["content"][ - event_response["event_id"] - ] - .get(ReceiptTypes.READ_PRIVATE, {}) - .keys(), - set(), - exact=True, - ) - - def test_wait_for_new_data_timeout(self) -> None: - """ - Test to make sure that the Sliding Sync request waits for new data to arrive but - no data ever arrives so we timeout. We're also making sure that the default data - from the receipts extension doesn't trigger a false-positive for new data. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - sync_body = { - "lists": {}, - "extensions": { - "receipts": { - "enabled": True, - } - }, - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - # Make the Sliding Sync request - channel = self.make_request( - "POST", - self.sync_endpoint + f"?timeout=10000&pos={from_token}", - content=sync_body, - access_token=user1_tok, - await_result=False, - ) - # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)` - with self.assertRaises(TimedOutException): - channel.await_result(timeout_ms=5000) - # Wake-up `notifier.wait_for_events(...)` that will cause us test - # `SlidingSyncResult.__bool__` for new results. - self._bump_notifier_wait_for_events( - user1_id, wake_stream_key=StreamKeyType.ACCOUNT_DATA - ) - # Block for a little bit more to ensure we don't see any new results. 
- with self.assertRaises(TimedOutException): - channel.await_result(timeout_ms=4000) - # Wait for the sync to complete (wait for the rest of the 10 second timeout, - # 5000 + 4000 + 1200 > 10000) - channel.await_result(timeout_ms=1200) - self.assertEqual(channel.code, 200, channel.json_body) - - self.assertIncludes( - channel.json_body["extensions"]["receipts"].get("rooms").keys(), - set(), - exact=True, - ) diff --git a/tests/rest/client/sliding_sync/test_extension_to_device.py b/tests/rest/client/sliding_sync/test_extension_to_device.py deleted file mode 100644
index f8500812ea..0000000000 --- a/tests/rest/client/sliding_sync/test_extension_to_device.py +++ /dev/null
@@ -1,278 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2024 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -import logging -from typing import List - -from twisted.test.proto_helpers import MemoryReactor - -import synapse.rest.admin -from synapse.rest.client import login, sendtodevice, sync -from synapse.server import HomeServer -from synapse.types import JsonDict, StreamKeyType -from synapse.util import Clock - -from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase -from tests.server import TimedOutException - -logger = logging.getLogger(__name__) - - -class SlidingSyncToDeviceExtensionTestCase(SlidingSyncBase): - """Tests for the to-device sliding sync extension""" - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - sync.register_servlets, - sendtodevice.register_servlets, - ] - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastores().main - - def _assert_to_device_response( - self, response_body: JsonDict, expected_messages: List[JsonDict] - ) -> str: - """Assert the sliding sync response was successful and has the expected - to-device messages. - - Returns the next_batch token from the to-device section. 
- """ - extensions = response_body["extensions"] - to_device = extensions["to_device"] - self.assertIsInstance(to_device["next_batch"], str) - self.assertEqual(to_device["events"], expected_messages) - - return to_device["next_batch"] - - def test_no_data(self) -> None: - """Test that enabling to-device extension works, even if there is - no-data - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - sync_body = { - "lists": {}, - "extensions": { - "to_device": { - "enabled": True, - } - }, - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # We expect no to-device messages - self._assert_to_device_response(response_body, []) - - def test_data_initial_sync(self) -> None: - """Test that we get to-device messages when we don't specify a since - token""" - - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass", "d1") - user2_id = self.register_user("u2", "pass") - user2_tok = self.login(user2_id, "pass", "d2") - - # Send the to-device message - test_msg = {"foo": "bar"} - chan = self.make_request( - "PUT", - "/_matrix/client/r0/sendToDevice/m.test/1234", - content={"messages": {user1_id: {"d1": test_msg}}}, - access_token=user2_tok, - ) - self.assertEqual(chan.code, 200, chan.result) - - sync_body = { - "lists": {}, - "extensions": { - "to_device": { - "enabled": True, - } - }, - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - self._assert_to_device_response( - response_body, - [{"content": test_msg, "sender": user2_id, "type": "m.test"}], - ) - - def test_data_incremental_sync(self) -> None: - """Test that we get to-device messages over incremental syncs""" - - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass", "d1") - user2_id = self.register_user("u2", "pass") - user2_tok = self.login(user2_id, "pass", "d2") - - sync_body: JsonDict = { - "lists": {}, - "extensions": { - "to_device": { - "enabled": True, - } - }, 
- } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - # No to-device messages yet. - next_batch = self._assert_to_device_response(response_body, []) - - test_msg = {"foo": "bar"} - chan = self.make_request( - "PUT", - "/_matrix/client/r0/sendToDevice/m.test/1234", - content={"messages": {user1_id: {"d1": test_msg}}}, - access_token=user2_tok, - ) - self.assertEqual(chan.code, 200, chan.result) - - sync_body = { - "lists": {}, - "extensions": { - "to_device": { - "enabled": True, - "since": next_batch, - } - }, - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - next_batch = self._assert_to_device_response( - response_body, - [{"content": test_msg, "sender": user2_id, "type": "m.test"}], - ) - - # The next sliding sync request should not include the to-device - # message. - sync_body = { - "lists": {}, - "extensions": { - "to_device": { - "enabled": True, - "since": next_batch, - } - }, - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - self._assert_to_device_response(response_body, []) - - # An initial sliding sync request should not include the to-device - # message, as it should have been deleted - sync_body = { - "lists": {}, - "extensions": { - "to_device": { - "enabled": True, - } - }, - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - self._assert_to_device_response(response_body, []) - - def test_wait_for_new_data(self) -> None: - """ - Test to make sure that the Sliding Sync request waits for new data to arrive. 
- - (Only applies to incremental syncs with a `timeout` specified) - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass", "d1") - user2_id = self.register_user("u2", "pass") - user2_tok = self.login(user2_id, "pass", "d2") - - sync_body = { - "lists": {}, - "extensions": { - "to_device": { - "enabled": True, - } - }, - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - # Make the Sliding Sync request - channel = self.make_request( - "POST", - self.sync_endpoint + "?timeout=10000" + f"&pos={from_token}", - content=sync_body, - access_token=user1_tok, - await_result=False, - ) - # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)` - with self.assertRaises(TimedOutException): - channel.await_result(timeout_ms=5000) - # Bump the to-device messages to trigger new results - test_msg = {"foo": "bar"} - send_to_device_channel = self.make_request( - "PUT", - "/_matrix/client/r0/sendToDevice/m.test/1234", - content={"messages": {user1_id: {"d1": test_msg}}}, - access_token=user2_tok, - ) - self.assertEqual( - send_to_device_channel.code, 200, send_to_device_channel.result - ) - # Should respond before the 10 second timeout - channel.await_result(timeout_ms=3000) - self.assertEqual(channel.code, 200, channel.json_body) - - self._assert_to_device_response( - channel.json_body, - [{"content": test_msg, "sender": user2_id, "type": "m.test"}], - ) - - def test_wait_for_new_data_timeout(self) -> None: - """ - Test to make sure that the Sliding Sync request waits for new data to arrive but - no data ever arrives so we timeout. We're also making sure that the default data - from the To-Device extension doesn't trigger a false-positive for new data. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - sync_body = { - "lists": {}, - "extensions": { - "to_device": { - "enabled": True, - } - }, - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - # Make the Sliding Sync request - channel = self.make_request( - "POST", - self.sync_endpoint + "?timeout=10000" + f"&pos={from_token}", - content=sync_body, - access_token=user1_tok, - await_result=False, - ) - # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)` - with self.assertRaises(TimedOutException): - channel.await_result(timeout_ms=5000) - # Wake-up `notifier.wait_for_events(...)` that will cause us test - # `SlidingSyncResult.__bool__` for new results. - self._bump_notifier_wait_for_events( - user1_id, wake_stream_key=StreamKeyType.ACCOUNT_DATA - ) - # Block for a little bit more to ensure we don't see any new results. - with self.assertRaises(TimedOutException): - channel.await_result(timeout_ms=4000) - # Wait for the sync to complete (wait for the rest of the 10 second timeout, - # 5000 + 4000 + 1200 > 10000) - channel.await_result(timeout_ms=1200) - self.assertEqual(channel.code, 200, channel.json_body) - - self._assert_to_device_response(channel.json_body, []) diff --git a/tests/rest/client/sliding_sync/test_extension_typing.py b/tests/rest/client/sliding_sync/test_extension_typing.py deleted file mode 100644
index 7f523e0f10..0000000000 --- a/tests/rest/client/sliding_sync/test_extension_typing.py +++ /dev/null
@@ -1,482 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2024 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -import logging - -from twisted.test.proto_helpers import MemoryReactor - -import synapse.rest.admin -from synapse.api.constants import EduTypes -from synapse.rest.client import login, room, sync -from synapse.server import HomeServer -from synapse.types import StreamKeyType -from synapse.util import Clock - -from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase -from tests.server import TimedOutException - -logger = logging.getLogger(__name__) - - -class SlidingSyncTypingExtensionTestCase(SlidingSyncBase): - """Tests for the typing notification sliding sync extension""" - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - sync.register_servlets, - ] - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastores().main - - def test_no_data_initial_sync(self) -> None: - """ - Test that enabling the typing extension works during an intitial sync, - even if there is no-data. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - # Make an initial Sliding Sync request with the typing extension enabled - sync_body = { - "lists": {}, - "extensions": { - "typing": { - "enabled": True, - } - }, - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - self.assertIncludes( - response_body["extensions"]["typing"].get("rooms").keys(), - set(), - exact=True, - ) - - def test_no_data_incremental_sync(self) -> None: - """ - Test that enabling typing extension works during an incremental sync, even - if there is no-data. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - sync_body = { - "lists": {}, - "extensions": { - "typing": { - "enabled": True, - } - }, - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - # Make an incremental Sliding Sync request with the typing extension enabled - response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok) - - self.assertIncludes( - response_body["extensions"]["typing"].get("rooms").keys(), - set(), - exact=True, - ) - - def test_typing_initial_sync(self) -> None: - """ - On initial sync, we return all typing notifications for rooms that we request - and are being returned in the Sliding Sync response. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - user3_id = self.register_user("user3", "pass") - user3_tok = self.login(user3_id, "pass") - user4_id = self.register_user("user4", "pass") - user4_tok = self.login(user4_id, "pass") - - # Create a room - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id1, user1_id, tok=user1_tok) - self.helper.join(room_id1, user3_id, tok=user3_tok) - self.helper.join(room_id1, user4_id, tok=user4_tok) - # User1 starts typing in room1 - channel = self.make_request( - "PUT", - f"/rooms/{room_id1}/typing/{user1_id}", - b'{"typing": true, "timeout": 30000}', - access_token=user1_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - # User2 starts typing in room1 - channel = self.make_request( - "PUT", - f"/rooms/{room_id1}/typing/{user2_id}", - b'{"typing": true, "timeout": 30000}', - access_token=user2_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Create another room - room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id2, user1_id, tok=user1_tok) - self.helper.join(room_id2, user3_id, tok=user3_tok) - self.helper.join(room_id2, user4_id, tok=user4_tok) - # User1 starts typing in room2 - channel = self.make_request( - "PUT", - f"/rooms/{room_id2}/typing/{user1_id}", - b'{"typing": true, "timeout": 30000}', - access_token=user1_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - # User2 starts typing in room2 - channel = self.make_request( - "PUT", - f"/rooms/{room_id2}/typing/{user2_id}", - b'{"typing": true, "timeout": 30000}', - access_token=user2_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Make an initial Sliding Sync request with the typing extension enabled - sync_body = { - "lists": {}, - "room_subscriptions": { - room_id1: { - 
"required_state": [], - "timeline_limit": 0, - } - }, - "extensions": { - "typing": { - "enabled": True, - "rooms": [room_id1, room_id2], - } - }, - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # Even though we requested room2, we only expect room1 to show up because that's - # the only room in the Sliding Sync response (room2 is not one of our room - # subscriptions or in a sliding window list). - self.assertIncludes( - response_body["extensions"]["typing"].get("rooms").keys(), - {room_id1}, - exact=True, - ) - # Sanity check that it's the correct ephemeral event type - self.assertEqual( - response_body["extensions"]["typing"]["rooms"][room_id1]["type"], - EduTypes.TYPING, - ) - # We can see user1 and user2 typing - self.assertIncludes( - set( - response_body["extensions"]["typing"]["rooms"][room_id1]["content"][ - "user_ids" - ] - ), - {user1_id, user2_id}, - exact=True, - ) - - def test_typing_incremental_sync(self) -> None: - """ - On incremental sync, we return all typing notifications in the token range for a - given room but only for rooms that we request and are being returned in the - Sliding Sync response. 
- """ - - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - user3_id = self.register_user("user3", "pass") - user3_tok = self.login(user3_id, "pass") - - # Create room1 - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id1, user1_id, tok=user1_tok) - self.helper.join(room_id1, user3_id, tok=user3_tok) - # User2 starts typing in room1 - channel = self.make_request( - "PUT", - f"/rooms/{room_id1}/typing/{user2_id}", - b'{"typing": true, "timeout": 30000}', - access_token=user2_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Create room2 - room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id2, user1_id, tok=user1_tok) - # User1 starts typing in room2 (before the `from_token`) - channel = self.make_request( - "PUT", - f"/rooms/{room_id2}/typing/{user1_id}", - b'{"typing": true, "timeout": 30000}', - access_token=user1_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Create room3 - room_id3 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id3, user1_id, tok=user1_tok) - self.helper.join(room_id3, user3_id, tok=user3_tok) - - # Create room4 - room_id4 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id4, user1_id, tok=user1_tok) - self.helper.join(room_id4, user3_id, tok=user3_tok) - # User1 starts typing in room4 (before the `from_token`) - channel = self.make_request( - "PUT", - f"/rooms/{room_id4}/typing/{user1_id}", - b'{"typing": true, "timeout": 30000}', - access_token=user1_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Advance time so all of the typing notifications timeout before we make our - # Sliding Sync requests. 
Even though these are sent before the `from_token`, the - # typing code only keeps track of stream position of the latest typing - # notification so "old" typing notifications that are still "alive" (haven't - # timed out) can appear in the response. - self.reactor.advance(36) - - sync_body = { - "lists": {}, - "room_subscriptions": { - room_id1: { - "required_state": [], - "timeline_limit": 0, - }, - room_id3: { - "required_state": [], - "timeline_limit": 0, - }, - room_id4: { - "required_state": [], - "timeline_limit": 0, - }, - }, - "extensions": { - "typing": { - "enabled": True, - "rooms": [room_id1, room_id2, room_id3, room_id4], - } - }, - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - # Add some more typing notifications after the `from_token` - # - # User1 starts typing in room1 - channel = self.make_request( - "PUT", - f"/rooms/{room_id1}/typing/{user1_id}", - b'{"typing": true, "timeout": 30000}', - access_token=user1_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - # User1 starts typing in room2 - channel = self.make_request( - "PUT", - f"/rooms/{room_id2}/typing/{user1_id}", - b'{"typing": true, "timeout": 30000}', - access_token=user1_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - # User3 starts typing in room3 - channel = self.make_request( - "PUT", - f"/rooms/{room_id3}/typing/{user3_id}", - b'{"typing": true, "timeout": 30000}', - access_token=user3_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - # No activity for room4 after the `from_token` - - # Make an incremental Sliding Sync request with the typing extension enabled - response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok) - - # Even though we requested room2, we only expect rooms to show up if they are - # already in the Sliding Sync response. room4 doesn't show up because there is - # no activity after the `from_token`. 
- self.assertIncludes( - response_body["extensions"]["typing"].get("rooms").keys(), - {room_id1, room_id3}, - exact=True, - ) - - # Check room1: - # - # Sanity check that it's the correct ephemeral event type - self.assertEqual( - response_body["extensions"]["typing"]["rooms"][room_id1]["type"], - EduTypes.TYPING, - ) - # We only see that user1 is typing in room1 since the `from_token` - self.assertIncludes( - set( - response_body["extensions"]["typing"]["rooms"][room_id1]["content"][ - "user_ids" - ] - ), - {user1_id}, - exact=True, - ) - - # Check room3: - # - # Sanity check that it's the correct ephemeral event type - self.assertEqual( - response_body["extensions"]["typing"]["rooms"][room_id3]["type"], - EduTypes.TYPING, - ) - # We only see that user3 is typing in room1 since the `from_token` - self.assertIncludes( - set( - response_body["extensions"]["typing"]["rooms"][room_id3]["content"][ - "user_ids" - ] - ), - {user3_id}, - exact=True, - ) - - def test_wait_for_new_data(self) -> None: - """ - Test to make sure that the Sliding Sync request waits for new data to arrive. 
- - (Only applies to incremental syncs with a `timeout` specified) - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id, user1_id, tok=user1_tok) - - sync_body = { - "lists": {}, - "room_subscriptions": { - room_id: { - "required_state": [], - "timeline_limit": 0, - }, - }, - "extensions": { - "typing": { - "enabled": True, - "rooms": [room_id], - } - }, - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - # Make an incremental Sliding Sync request with the typing extension enabled - channel = self.make_request( - "POST", - self.sync_endpoint + f"?timeout=10000&pos={from_token}", - content=sync_body, - access_token=user1_tok, - await_result=False, - ) - # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)` - with self.assertRaises(TimedOutException): - channel.await_result(timeout_ms=5000) - # Bump the typing status to trigger new results - typing_channel = self.make_request( - "PUT", - f"/rooms/{room_id}/typing/{user2_id}", - b'{"typing": true, "timeout": 30000}', - access_token=user2_tok, - ) - self.assertEqual(typing_channel.code, 200, typing_channel.json_body) - # Should respond before the 10 second timeout - channel.await_result(timeout_ms=3000) - self.assertEqual(channel.code, 200, channel.json_body) - - # We should see the new typing notification - self.assertIncludes( - channel.json_body.get("extensions", {}) - .get("typing", {}) - .get("rooms", {}) - .keys(), - {room_id}, - exact=True, - message=str(channel.json_body), - ) - self.assertIncludes( - set( - channel.json_body["extensions"]["typing"]["rooms"][room_id]["content"][ - "user_ids" - ] - ), - {user2_id}, - exact=True, - ) - - def test_wait_for_new_data_timeout(self) -> None: - """ - Test to make sure that the Sliding Sync request waits for new 
data to arrive but - no data ever arrives so we timeout. We're also making sure that the default data - from the typing extension doesn't trigger a false-positive for new data. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - sync_body = { - "lists": {}, - "extensions": { - "typing": { - "enabled": True, - } - }, - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - # Make the Sliding Sync request - channel = self.make_request( - "POST", - self.sync_endpoint + f"?timeout=10000&pos={from_token}", - content=sync_body, - access_token=user1_tok, - await_result=False, - ) - # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)` - with self.assertRaises(TimedOutException): - channel.await_result(timeout_ms=5000) - # Wake-up `notifier.wait_for_events(...)` that will cause us test - # `SlidingSyncResult.__bool__` for new results. - self._bump_notifier_wait_for_events( - user1_id, wake_stream_key=StreamKeyType.ACCOUNT_DATA - ) - # Block for a little bit more to ensure we don't see any new results. - with self.assertRaises(TimedOutException): - channel.await_result(timeout_ms=4000) - # Wait for the sync to complete (wait for the rest of the 10 second timeout, - # 5000 + 4000 + 1200 > 10000) - channel.await_result(timeout_ms=1200) - self.assertEqual(channel.code, 200, channel.json_body) - - self.assertIncludes( - channel.json_body["extensions"]["typing"].get("rooms").keys(), - set(), - exact=True, - ) diff --git a/tests/rest/client/sliding_sync/test_extensions.py b/tests/rest/client/sliding_sync/test_extensions.py deleted file mode 100644
index 68f6661334..0000000000 --- a/tests/rest/client/sliding_sync/test_extensions.py +++ /dev/null
@@ -1,283 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2024 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -import logging -from typing import Literal - -from parameterized import parameterized -from typing_extensions import assert_never - -from twisted.test.proto_helpers import MemoryReactor - -import synapse.rest.admin -from synapse.api.constants import ReceiptTypes -from synapse.rest.client import login, receipts, room, sync -from synapse.server import HomeServer -from synapse.util import Clock - -from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase - -logger = logging.getLogger(__name__) - - -class SlidingSyncExtensionsTestCase(SlidingSyncBase): - """ - Test general extensions behavior in the Sliding Sync API. Each extension has their - own suite of tests in their own file as well. 
- """ - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - sync.register_servlets, - receipts.register_servlets, - ] - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastores().main - self.storage_controllers = hs.get_storage_controllers() - self.account_data_handler = hs.get_account_data_handler() - - # Any extensions that use `lists`/`rooms` should be tested here - @parameterized.expand([("account_data",), ("receipts",), ("typing",)]) - def test_extensions_lists_rooms_relevant_rooms( - self, - extension_name: Literal["account_data", "receipts", "typing"], - ) -> None: - """ - With various extensions, test out requesting different variations of - `lists`/`rooms`. - - Stresses `SlidingSyncHandler.find_relevant_room_ids_for_extension(...)` - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - # Create some rooms - room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok) - room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok) - room_id3 = self.helper.create_room_as(user1_id, tok=user1_tok) - room_id4 = self.helper.create_room_as(user1_id, tok=user1_tok) - room_id5 = self.helper.create_room_as(user1_id, tok=user1_tok) - - room_id_to_human_name_map = { - room_id1: "room1", - room_id2: "room2", - room_id3: "room3", - room_id4: "room4", - room_id5: "room5", - } - - for room_id in room_id_to_human_name_map.keys(): - if extension_name == "account_data": - # Add some account data to each room - self.get_success( - self.account_data_handler.add_account_data_to_room( - user_id=user1_id, - room_id=room_id, - account_data_type="org.matrix.roorarraz", - content={"roo": "rar"}, - ) - ) - elif extension_name == "receipts": - event_response = self.helper.send( - room_id, body="new event", tok=user1_tok - ) - # Read last event - channel = self.make_request( - "POST", - 
f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{event_response['event_id']}", - {}, - access_token=user1_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - elif extension_name == "typing": - # Start a typing notification - channel = self.make_request( - "PUT", - f"/rooms/{room_id}/typing/{user1_id}", - b'{"typing": true, "timeout": 30000}', - access_token=user1_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - else: - assert_never(extension_name) - - main_sync_body = { - "lists": { - # We expect this list range to include room5 and room4 - "foo-list": { - "ranges": [[0, 1]], - "required_state": [], - "timeline_limit": 0, - }, - # We expect this list range to include room5, room4, room3 - "bar-list": { - "ranges": [[0, 2]], - "required_state": [], - "timeline_limit": 0, - }, - }, - "room_subscriptions": { - room_id1: { - "required_state": [], - "timeline_limit": 0, - } - }, - } - - # Mix lists and rooms - sync_body = { - **main_sync_body, - "extensions": { - extension_name: { - "enabled": True, - "lists": ["foo-list", "non-existent-list"], - "rooms": [room_id1, room_id2, "!non-existent-room"], - } - }, - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # room1: ✅ Requested via `rooms` and a room subscription exists - # room2: ❌ Requested via `rooms` but not in the response (from lists or room subscriptions) - # room3: ❌ Not requested - # room4: ✅ Shows up because requested via `lists` and list exists in the response - # room5: ✅ Shows up because requested via `lists` and list exists in the response - self.assertIncludes( - { - room_id_to_human_name_map[room_id] - for room_id in response_body["extensions"][extension_name] - .get("rooms") - .keys() - }, - {"room1", "room4", "room5"}, - exact=True, - ) - - # Try wildcards (this is the default) - sync_body = { - **main_sync_body, - "extensions": { - extension_name: { - "enabled": True, - # "lists": ["*"], - # "rooms": ["*"], - } - }, - } - response_body, _ = 
self.do_sync(sync_body, tok=user1_tok) - - # room1: ✅ Shows up because of default `rooms` wildcard and is in one of the room subscriptions - # room2: ❌ Not requested - # room3: ✅ Shows up because of default `lists` wildcard and is in a list - # room4: ✅ Shows up because of default `lists` wildcard and is in a list - # room5: ✅ Shows up because of default `lists` wildcard and is in a list - self.assertIncludes( - { - room_id_to_human_name_map[room_id] - for room_id in response_body["extensions"][extension_name] - .get("rooms") - .keys() - }, - {"room1", "room3", "room4", "room5"}, - exact=True, - ) - - # Empty list will return nothing - sync_body = { - **main_sync_body, - "extensions": { - extension_name: { - "enabled": True, - "lists": [], - "rooms": [], - } - }, - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # room1: ❌ Not requested - # room2: ❌ Not requested - # room3: ❌ Not requested - # room4: ❌ Not requested - # room5: ❌ Not requested - self.assertIncludes( - { - room_id_to_human_name_map[room_id] - for room_id in response_body["extensions"][extension_name] - .get("rooms") - .keys() - }, - set(), - exact=True, - ) - - # Try wildcard and none - sync_body = { - **main_sync_body, - "extensions": { - extension_name: { - "enabled": True, - "lists": ["*"], - "rooms": [], - } - }, - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # room1: ❌ Not requested - # room2: ❌ Not requested - # room3: ✅ Shows up because of default `lists` wildcard and is in a list - # room4: ✅ Shows up because of default `lists` wildcard and is in a list - # room5: ✅ Shows up because of default `lists` wildcard and is in a list - self.assertIncludes( - { - room_id_to_human_name_map[room_id] - for room_id in response_body["extensions"][extension_name] - .get("rooms") - .keys() - }, - {"room3", "room4", "room5"}, - exact=True, - ) - - # Try requesting a room that is only in a list - sync_body = { - **main_sync_body, - "extensions": { - extension_name: { - 
"enabled": True, - "lists": [], - "rooms": [room_id5], - } - }, - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # room1: ❌ Not requested - # room2: ❌ Not requested - # room3: ❌ Not requested - # room4: ❌ Not requested - # room5: ✅ Requested via `rooms` and is in a list - self.assertIncludes( - { - room_id_to_human_name_map[room_id] - for room_id in response_body["extensions"][extension_name] - .get("rooms") - .keys() - }, - {"room5"}, - exact=True, - ) diff --git a/tests/rest/client/sliding_sync/test_room_subscriptions.py b/tests/rest/client/sliding_sync/test_room_subscriptions.py deleted file mode 100644
index cc17b0b354..0000000000 --- a/tests/rest/client/sliding_sync/test_room_subscriptions.py +++ /dev/null
@@ -1,285 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2024 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -import logging -from http import HTTPStatus - -from twisted.test.proto_helpers import MemoryReactor - -import synapse.rest.admin -from synapse.api.constants import EventTypes, HistoryVisibility -from synapse.rest.client import login, room, sync -from synapse.server import HomeServer -from synapse.util import Clock - -from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase - -logger = logging.getLogger(__name__) - - -class SlidingSyncRoomSubscriptionsTestCase(SlidingSyncBase): - """ - Test `room_subscriptions` in the Sliding Sync API. - """ - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - sync.register_servlets, - ] - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastores().main - self.storage_controllers = hs.get_storage_controllers() - - def test_room_subscriptions_with_join_membership(self) -> None: - """ - Test `room_subscriptions` with a joined room should give us timeline and current - state events. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - join_response = self.helper.join(room_id1, user1_id, tok=user1_tok) - - # Make the Sliding Sync request with just the room subscription - sync_body = { - "room_subscriptions": { - room_id1: { - "required_state": [ - [EventTypes.Create, ""], - ], - "timeline_limit": 1, - } - }, - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - state_map = self.get_success( - self.storage_controllers.state.get_current_state(room_id1) - ) - - # We should see some state - self._assertRequiredStateIncludes( - response_body["rooms"][room_id1]["required_state"], - { - state_map[(EventTypes.Create, "")], - }, - exact=True, - ) - self.assertIsNone(response_body["rooms"][room_id1].get("invite_state")) - - # We should see some events - self.assertEqual( - [ - event["event_id"] - for event in response_body["rooms"][room_id1]["timeline"] - ], - [ - join_response["event_id"], - ], - response_body["rooms"][room_id1]["timeline"], - ) - # No "live" events in an initial sync (no `from_token` to define the "live" - # range) - self.assertEqual( - response_body["rooms"][room_id1]["num_live"], - 0, - response_body["rooms"][room_id1], - ) - # There are more events to paginate to - self.assertEqual( - response_body["rooms"][room_id1]["limited"], - True, - response_body["rooms"][room_id1], - ) - - def test_room_subscriptions_with_leave_membership(self) -> None: - """ - Test `room_subscriptions` with a leave room should give us timeline and state - events up to the leave event. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.send_state( - room_id1, - event_type="org.matrix.foo_state", - state_key="", - body={"foo": "bar"}, - tok=user2_tok, - ) - - join_response = self.helper.join(room_id1, user1_id, tok=user1_tok) - leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok) - - state_map = self.get_success( - self.storage_controllers.state.get_current_state(room_id1) - ) - - # Send some events after user1 leaves - self.helper.send(room_id1, "activity after leave", tok=user2_tok) - # Update state after user1 leaves - self.helper.send_state( - room_id1, - event_type="org.matrix.foo_state", - state_key="", - body={"foo": "qux"}, - tok=user2_tok, - ) - - # Make the Sliding Sync request with just the room subscription - sync_body = { - "room_subscriptions": { - room_id1: { - "required_state": [ - ["org.matrix.foo_state", ""], - ], - "timeline_limit": 2, - } - }, - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # We should see the state at the time of the leave - self._assertRequiredStateIncludes( - response_body["rooms"][room_id1]["required_state"], - { - state_map[("org.matrix.foo_state", "")], - }, - exact=True, - ) - self.assertIsNone(response_body["rooms"][room_id1].get("invite_state")) - - # We should see some before we left (nothing after) - self.assertEqual( - [ - event["event_id"] - for event in response_body["rooms"][room_id1]["timeline"] - ], - [ - join_response["event_id"], - leave_response["event_id"], - ], - response_body["rooms"][room_id1]["timeline"], - ) - # No "live" events in an initial sync (no `from_token` to define the "live" - # range) - self.assertEqual( - response_body["rooms"][room_id1]["num_live"], - 0, - response_body["rooms"][room_id1], - ) - # There are more events 
to paginate to - self.assertEqual( - response_body["rooms"][room_id1]["limited"], - True, - response_body["rooms"][room_id1], - ) - - def test_room_subscriptions_no_leak_private_room(self) -> None: - """ - Test `room_subscriptions` with a private room we have never been in should not - leak any data to the user. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=False) - - # We should not be able to join the private room - self.helper.join( - room_id1, user1_id, tok=user1_tok, expect_code=HTTPStatus.FORBIDDEN - ) - - # Make the Sliding Sync request with just the room subscription - sync_body = { - "room_subscriptions": { - room_id1: { - "required_state": [ - [EventTypes.Create, ""], - ], - "timeline_limit": 1, - } - }, - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # We should not see the room at all (we're not in it) - self.assertIsNone(response_body["rooms"].get(room_id1), response_body["rooms"]) - - def test_room_subscriptions_world_readable(self) -> None: - """ - Test `room_subscriptions` with a room that has `world_readable` history visibility - - FIXME: We should be able to see the room timeline and state - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - # Create a room with `world_readable` history visibility - room_id1 = self.helper.create_room_as( - user2_id, - tok=user2_tok, - extra_content={ - "preset": "public_chat", - "initial_state": [ - { - "content": { - "history_visibility": HistoryVisibility.WORLD_READABLE - }, - "state_key": "", - "type": EventTypes.RoomHistoryVisibility, - } - ], - }, - ) - # Ensure we're testing with a room with `world_readable` history visibility - # 
which means events are visible to anyone even without membership. - history_visibility_response = self.helper.get_state( - room_id1, EventTypes.RoomHistoryVisibility, tok=user2_tok - ) - self.assertEqual( - history_visibility_response.get("history_visibility"), - HistoryVisibility.WORLD_READABLE, - ) - - # Note: We never join the room - - # Make the Sliding Sync request with just the room subscription - sync_body = { - "room_subscriptions": { - room_id1: { - "required_state": [ - [EventTypes.Create, ""], - ], - "timeline_limit": 1, - } - }, - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # FIXME: In the future, we should be able to see the room because it's - # `world_readable` but currently we don't support this. - self.assertIsNone(response_body["rooms"].get(room_id1), response_body["rooms"]) diff --git a/tests/rest/client/sliding_sync/test_rooms_invites.py b/tests/rest/client/sliding_sync/test_rooms_invites.py deleted file mode 100644
index f08ffaf674..0000000000 --- a/tests/rest/client/sliding_sync/test_rooms_invites.py +++ /dev/null
@@ -1,510 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2024 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -import logging - -from twisted.test.proto_helpers import MemoryReactor - -import synapse.rest.admin -from synapse.api.constants import EventTypes, HistoryVisibility -from synapse.rest.client import login, room, sync -from synapse.server import HomeServer -from synapse.types import UserID -from synapse.util import Clock - -from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase - -logger = logging.getLogger(__name__) - - -class SlidingSyncRoomsInvitesTestCase(SlidingSyncBase): - """ - Test to make sure the `rooms` response looks good for invites in the Sliding Sync API. - - Invites behave a lot different than other rooms because we don't include the - `timeline` (`num_live`, `limited`, `prev_batch`) or `required_state` in favor of - some stripped state under the `invite_state` key. - - Knocks probably have the same behavior but the spec doesn't mention knocks yet. - """ - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - sync.register_servlets, - ] - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastores().main - self.storage_controllers = hs.get_storage_controllers() - - def test_rooms_invite_shared_history_initial_sync(self) -> None: - """ - Test that `rooms` we are invited to have some stripped `invite_state` during an - initial sync. 
- - This is an `invite` room so we should only have `stripped_state` (no `timeline`) - but we also shouldn't see any timeline events because the history visiblity is - `shared` and we haven't joined the room yet. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user1 = UserID.from_string(user1_id) - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - user2 = UserID.from_string(user2_id) - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - # Ensure we're testing with a room with `shared` history visibility which means - # history visible until you actually join the room. - history_visibility_response = self.helper.get_state( - room_id1, EventTypes.RoomHistoryVisibility, tok=user2_tok - ) - self.assertEqual( - history_visibility_response.get("history_visibility"), - HistoryVisibility.SHARED, - ) - - self.helper.send(room_id1, "activity before1", tok=user2_tok) - self.helper.send(room_id1, "activity before2", tok=user2_tok) - self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok) - self.helper.send(room_id1, "activity after3", tok=user2_tok) - self.helper.send(room_id1, "activity after4", tok=user2_tok) - - # Make the Sliding Sync request - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [], - "timeline_limit": 3, - } - } - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # `timeline` is omitted for `invite` rooms with `stripped_state` - self.assertIsNone( - response_body["rooms"][room_id1].get("timeline"), - response_body["rooms"][room_id1], - ) - # `num_live` is omitted for `invite` rooms with `stripped_state` (no timeline anyway) - self.assertIsNone( - response_body["rooms"][room_id1].get("num_live"), - response_body["rooms"][room_id1], - ) - # `limited` is omitted for `invite` rooms with `stripped_state` (no timeline anyway) - self.assertIsNone( - 
response_body["rooms"][room_id1].get("limited"), - response_body["rooms"][room_id1], - ) - # `prev_batch` is omitted for `invite` rooms with `stripped_state` (no timeline anyway) - self.assertIsNone( - response_body["rooms"][room_id1].get("prev_batch"), - response_body["rooms"][room_id1], - ) - # `required_state` is omitted for `invite` rooms with `stripped_state` - self.assertIsNone( - response_body["rooms"][room_id1].get("required_state"), - response_body["rooms"][room_id1], - ) - # We should have some `stripped_state` so the potential joiner can identify the - # room (we don't care about the order). - self.assertCountEqual( - response_body["rooms"][room_id1]["invite_state"], - [ - { - "content": {"creator": user2_id, "room_version": "10"}, - "sender": user2_id, - "state_key": "", - "type": "m.room.create", - }, - { - "content": {"join_rule": "public"}, - "sender": user2_id, - "state_key": "", - "type": "m.room.join_rules", - }, - { - "content": {"displayname": user2.localpart, "membership": "join"}, - "sender": user2_id, - "state_key": user2_id, - "type": "m.room.member", - }, - { - "content": {"displayname": user1.localpart, "membership": "invite"}, - "sender": user2_id, - "state_key": user1_id, - "type": "m.room.member", - }, - ], - response_body["rooms"][room_id1]["invite_state"], - ) - - def test_rooms_invite_shared_history_incremental_sync(self) -> None: - """ - Test that `rooms` we are invited to have some stripped `invite_state` during an - incremental sync. - - This is an `invite` room so we should only have `stripped_state` (no `timeline`) - but we also shouldn't see any timeline events because the history visiblity is - `shared` and we haven't joined the room yet. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user1 = UserID.from_string(user1_id) - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - user2 = UserID.from_string(user2_id) - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - # Ensure we're testing with a room with `shared` history visibility which means - # history visible until you actually join the room. - history_visibility_response = self.helper.get_state( - room_id1, EventTypes.RoomHistoryVisibility, tok=user2_tok - ) - self.assertEqual( - history_visibility_response.get("history_visibility"), - HistoryVisibility.SHARED, - ) - - self.helper.send(room_id1, "activity before invite1", tok=user2_tok) - self.helper.send(room_id1, "activity before invite2", tok=user2_tok) - self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok) - self.helper.send(room_id1, "activity after invite3", tok=user2_tok) - self.helper.send(room_id1, "activity after invite4", tok=user2_tok) - - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [], - "timeline_limit": 3, - } - } - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - self.helper.send(room_id1, "activity after token5", tok=user2_tok) - self.helper.send(room_id1, "activity after toekn6", tok=user2_tok) - - # Make the Sliding Sync request - response_body, from_token = self.do_sync( - sync_body, since=from_token, tok=user1_tok - ) - - # `timeline` is omitted for `invite` rooms with `stripped_state` - self.assertIsNone( - response_body["rooms"][room_id1].get("timeline"), - response_body["rooms"][room_id1], - ) - # `num_live` is omitted for `invite` rooms with `stripped_state` (no timeline anyway) - self.assertIsNone( - response_body["rooms"][room_id1].get("num_live"), - response_body["rooms"][room_id1], - ) - # `limited` is omitted for `invite` rooms with `stripped_state` (no timeline anyway) - 
self.assertIsNone( - response_body["rooms"][room_id1].get("limited"), - response_body["rooms"][room_id1], - ) - # `prev_batch` is omitted for `invite` rooms with `stripped_state` (no timeline anyway) - self.assertIsNone( - response_body["rooms"][room_id1].get("prev_batch"), - response_body["rooms"][room_id1], - ) - # `required_state` is omitted for `invite` rooms with `stripped_state` - self.assertIsNone( - response_body["rooms"][room_id1].get("required_state"), - response_body["rooms"][room_id1], - ) - # We should have some `stripped_state` so the potential joiner can identify the - # room (we don't care about the order). - self.assertCountEqual( - response_body["rooms"][room_id1]["invite_state"], - [ - { - "content": {"creator": user2_id, "room_version": "10"}, - "sender": user2_id, - "state_key": "", - "type": "m.room.create", - }, - { - "content": {"join_rule": "public"}, - "sender": user2_id, - "state_key": "", - "type": "m.room.join_rules", - }, - { - "content": {"displayname": user2.localpart, "membership": "join"}, - "sender": user2_id, - "state_key": user2_id, - "type": "m.room.member", - }, - { - "content": {"displayname": user1.localpart, "membership": "invite"}, - "sender": user2_id, - "state_key": user1_id, - "type": "m.room.member", - }, - ], - response_body["rooms"][room_id1]["invite_state"], - ) - - def test_rooms_invite_world_readable_history_initial_sync(self) -> None: - """ - Test that `rooms` we are invited to have some stripped `invite_state` during an - initial sync. - - This is an `invite` room so we should only have `stripped_state` (no `timeline`) - but depending on the semantics we decide, we could potentially see some - historical events before/after the `from_token` because the history is - `world_readable`. Same situation for events after the `from_token` if the - history visibility was set to `invited`. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user1 = UserID.from_string(user1_id) - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - user2 = UserID.from_string(user2_id) - - room_id1 = self.helper.create_room_as( - user2_id, - tok=user2_tok, - extra_content={ - "preset": "public_chat", - "initial_state": [ - { - "content": { - "history_visibility": HistoryVisibility.WORLD_READABLE - }, - "state_key": "", - "type": EventTypes.RoomHistoryVisibility, - } - ], - }, - ) - # Ensure we're testing with a room with `world_readable` history visibility - # which means events are visible to anyone even without membership. - history_visibility_response = self.helper.get_state( - room_id1, EventTypes.RoomHistoryVisibility, tok=user2_tok - ) - self.assertEqual( - history_visibility_response.get("history_visibility"), - HistoryVisibility.WORLD_READABLE, - ) - - self.helper.send(room_id1, "activity before1", tok=user2_tok) - self.helper.send(room_id1, "activity before2", tok=user2_tok) - self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok) - self.helper.send(room_id1, "activity after3", tok=user2_tok) - self.helper.send(room_id1, "activity after4", tok=user2_tok) - - # Make the Sliding Sync request - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [], - # Large enough to see the latest events and before the invite - "timeline_limit": 4, - } - } - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # `timeline` is omitted for `invite` rooms with `stripped_state` - self.assertIsNone( - response_body["rooms"][room_id1].get("timeline"), - response_body["rooms"][room_id1], - ) - # `num_live` is omitted for `invite` rooms with `stripped_state` (no timeline anyway) - self.assertIsNone( - response_body["rooms"][room_id1].get("num_live"), - response_body["rooms"][room_id1], - ) - # `limited` is omitted for `invite` rooms 
with `stripped_state` (no timeline anyway) - self.assertIsNone( - response_body["rooms"][room_id1].get("limited"), - response_body["rooms"][room_id1], - ) - # `prev_batch` is omitted for `invite` rooms with `stripped_state` (no timeline anyway) - self.assertIsNone( - response_body["rooms"][room_id1].get("prev_batch"), - response_body["rooms"][room_id1], - ) - # `required_state` is omitted for `invite` rooms with `stripped_state` - self.assertIsNone( - response_body["rooms"][room_id1].get("required_state"), - response_body["rooms"][room_id1], - ) - # We should have some `stripped_state` so the potential joiner can identify the - # room (we don't care about the order). - self.assertCountEqual( - response_body["rooms"][room_id1]["invite_state"], - [ - { - "content": {"creator": user2_id, "room_version": "10"}, - "sender": user2_id, - "state_key": "", - "type": "m.room.create", - }, - { - "content": {"join_rule": "public"}, - "sender": user2_id, - "state_key": "", - "type": "m.room.join_rules", - }, - { - "content": {"displayname": user2.localpart, "membership": "join"}, - "sender": user2_id, - "state_key": user2_id, - "type": "m.room.member", - }, - { - "content": {"displayname": user1.localpart, "membership": "invite"}, - "sender": user2_id, - "state_key": user1_id, - "type": "m.room.member", - }, - ], - response_body["rooms"][room_id1]["invite_state"], - ) - - def test_rooms_invite_world_readable_history_incremental_sync(self) -> None: - """ - Test that `rooms` we are invited to have some stripped `invite_state` during an - incremental sync. - - This is an `invite` room so we should only have `stripped_state` (no `timeline`) - but depending on the semantics we decide, we could potentially see some - historical events before/after the `from_token` because the history is - `world_readable`. Same situation for events after the `from_token` if the - history visibility was set to `invited`. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user1 = UserID.from_string(user1_id) - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - user2 = UserID.from_string(user2_id) - - room_id1 = self.helper.create_room_as( - user2_id, - tok=user2_tok, - extra_content={ - "preset": "public_chat", - "initial_state": [ - { - "content": { - "history_visibility": HistoryVisibility.WORLD_READABLE - }, - "state_key": "", - "type": EventTypes.RoomHistoryVisibility, - } - ], - }, - ) - # Ensure we're testing with a room with `world_readable` history visibility - # which means events are visible to anyone even without membership. - history_visibility_response = self.helper.get_state( - room_id1, EventTypes.RoomHistoryVisibility, tok=user2_tok - ) - self.assertEqual( - history_visibility_response.get("history_visibility"), - HistoryVisibility.WORLD_READABLE, - ) - - self.helper.send(room_id1, "activity before invite1", tok=user2_tok) - self.helper.send(room_id1, "activity before invite2", tok=user2_tok) - self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok) - self.helper.send(room_id1, "activity after invite3", tok=user2_tok) - self.helper.send(room_id1, "activity after invite4", tok=user2_tok) - - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [], - # Large enough to see the latest events and before the invite - "timeline_limit": 4, - } - } - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - self.helper.send(room_id1, "activity after token5", tok=user2_tok) - self.helper.send(room_id1, "activity after toekn6", tok=user2_tok) - - # Make the incremental Sliding Sync request - response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok) - - # `timeline` is omitted for `invite` rooms with `stripped_state` - self.assertIsNone( - response_body["rooms"][room_id1].get("timeline"), - 
response_body["rooms"][room_id1], - ) - # `num_live` is omitted for `invite` rooms with `stripped_state` (no timeline anyway) - self.assertIsNone( - response_body["rooms"][room_id1].get("num_live"), - response_body["rooms"][room_id1], - ) - # `limited` is omitted for `invite` rooms with `stripped_state` (no timeline anyway) - self.assertIsNone( - response_body["rooms"][room_id1].get("limited"), - response_body["rooms"][room_id1], - ) - # `prev_batch` is omitted for `invite` rooms with `stripped_state` (no timeline anyway) - self.assertIsNone( - response_body["rooms"][room_id1].get("prev_batch"), - response_body["rooms"][room_id1], - ) - # `required_state` is omitted for `invite` rooms with `stripped_state` - self.assertIsNone( - response_body["rooms"][room_id1].get("required_state"), - response_body["rooms"][room_id1], - ) - # We should have some `stripped_state` so the potential joiner can identify the - # room (we don't care about the order). - self.assertCountEqual( - response_body["rooms"][room_id1]["invite_state"], - [ - { - "content": {"creator": user2_id, "room_version": "10"}, - "sender": user2_id, - "state_key": "", - "type": "m.room.create", - }, - { - "content": {"join_rule": "public"}, - "sender": user2_id, - "state_key": "", - "type": "m.room.join_rules", - }, - { - "content": {"displayname": user2.localpart, "membership": "join"}, - "sender": user2_id, - "state_key": user2_id, - "type": "m.room.member", - }, - { - "content": {"displayname": user1.localpart, "membership": "invite"}, - "sender": user2_id, - "state_key": user1_id, - "type": "m.room.member", - }, - ], - response_body["rooms"][room_id1]["invite_state"], - ) diff --git a/tests/rest/client/sliding_sync/test_rooms_meta.py b/tests/rest/client/sliding_sync/test_rooms_meta.py deleted file mode 100644
index 04f11c0524..0000000000 --- a/tests/rest/client/sliding_sync/test_rooms_meta.py +++ /dev/null
@@ -1,710 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2024 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -import logging - -from twisted.test.proto_helpers import MemoryReactor - -import synapse.rest.admin -from synapse.api.constants import EventTypes, Membership -from synapse.api.room_versions import RoomVersions -from synapse.rest.client import login, room, sync -from synapse.server import HomeServer -from synapse.util import Clock - -from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase -from tests.test_utils.event_injection import create_event - -logger = logging.getLogger(__name__) - - -class SlidingSyncRoomsMetaTestCase(SlidingSyncBase): - """ - Test rooms meta info like name, avatar, joined_count, invited_count, is_dm, - bump_stamp in the Sliding Sync API. - """ - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - sync.register_servlets, - ] - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastores().main - self.storage_controllers = hs.get_storage_controllers() - - def test_rooms_meta_when_joined(self) -> None: - """ - Test that the `rooms` `name` and `avatar` are included in the response and - reflect the current state of the room when the user is joined to the room. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id1 = self.helper.create_room_as( - user2_id, - tok=user2_tok, - extra_content={ - "name": "my super room", - }, - ) - # Set the room avatar URL - self.helper.send_state( - room_id1, - EventTypes.RoomAvatar, - {"url": "mxc://DUMMY_MEDIA_ID"}, - tok=user2_tok, - ) - - self.helper.join(room_id1, user1_id, tok=user1_tok) - - # Make the Sliding Sync request - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [], - "timeline_limit": 0, - } - } - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # Reflect the current state of the room - self.assertEqual( - response_body["rooms"][room_id1]["name"], - "my super room", - response_body["rooms"][room_id1], - ) - self.assertEqual( - response_body["rooms"][room_id1]["avatar"], - "mxc://DUMMY_MEDIA_ID", - response_body["rooms"][room_id1], - ) - self.assertEqual( - response_body["rooms"][room_id1]["joined_count"], - 2, - ) - self.assertEqual( - response_body["rooms"][room_id1]["invited_count"], - 0, - ) - self.assertIsNone( - response_body["rooms"][room_id1].get("is_dm"), - ) - - def test_rooms_meta_when_invited(self) -> None: - """ - Test that the `rooms` `name` and `avatar` are included in the response and - reflect the current state of the room when the user is invited to the room. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id1 = self.helper.create_room_as( - user2_id, - tok=user2_tok, - extra_content={ - "name": "my super room", - }, - ) - # Set the room avatar URL - self.helper.send_state( - room_id1, - EventTypes.RoomAvatar, - {"url": "mxc://DUMMY_MEDIA_ID"}, - tok=user2_tok, - ) - - # User1 is invited to the room - self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok) - - # Update the room name after user1 has left - self.helper.send_state( - room_id1, - EventTypes.Name, - {"name": "my super duper room"}, - tok=user2_tok, - ) - # Update the room avatar URL after user1 has left - self.helper.send_state( - room_id1, - EventTypes.RoomAvatar, - {"url": "mxc://UPDATED_DUMMY_MEDIA_ID"}, - tok=user2_tok, - ) - - # Make the Sliding Sync request - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [], - "timeline_limit": 0, - } - } - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # This should still reflect the current state of the room even when the user is - # invited. - self.assertEqual( - response_body["rooms"][room_id1]["name"], - "my super duper room", - response_body["rooms"][room_id1], - ) - self.assertEqual( - response_body["rooms"][room_id1]["avatar"], - "mxc://UPDATED_DUMMY_MEDIA_ID", - response_body["rooms"][room_id1], - ) - self.assertEqual( - response_body["rooms"][room_id1]["joined_count"], - 1, - ) - self.assertEqual( - response_body["rooms"][room_id1]["invited_count"], - 1, - ) - self.assertIsNone( - response_body["rooms"][room_id1].get("is_dm"), - ) - - def test_rooms_meta_when_banned(self) -> None: - """ - Test that the `rooms` `name` and `avatar` reflect the state of the room when the - user was banned (do not leak current state). 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id1 = self.helper.create_room_as( - user2_id, - tok=user2_tok, - extra_content={ - "name": "my super room", - }, - ) - # Set the room avatar URL - self.helper.send_state( - room_id1, - EventTypes.RoomAvatar, - {"url": "mxc://DUMMY_MEDIA_ID"}, - tok=user2_tok, - ) - - self.helper.join(room_id1, user1_id, tok=user1_tok) - self.helper.ban(room_id1, src=user2_id, targ=user1_id, tok=user2_tok) - - # Update the room name after user1 has left - self.helper.send_state( - room_id1, - EventTypes.Name, - {"name": "my super duper room"}, - tok=user2_tok, - ) - # Update the room avatar URL after user1 has left - self.helper.send_state( - room_id1, - EventTypes.RoomAvatar, - {"url": "mxc://UPDATED_DUMMY_MEDIA_ID"}, - tok=user2_tok, - ) - - # Make the Sliding Sync request - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [], - "timeline_limit": 0, - } - } - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # Reflect the state of the room at the time of leaving - self.assertEqual( - response_body["rooms"][room_id1]["name"], - "my super room", - response_body["rooms"][room_id1], - ) - self.assertEqual( - response_body["rooms"][room_id1]["avatar"], - "mxc://DUMMY_MEDIA_ID", - response_body["rooms"][room_id1], - ) - self.assertEqual( - response_body["rooms"][room_id1]["joined_count"], - # FIXME: The actual number should be "1" (user2) but we currently don't - # support this for rooms where the user has left/been banned. 
- 0, - ) - self.assertEqual( - response_body["rooms"][room_id1]["invited_count"], - 0, - ) - self.assertIsNone( - response_body["rooms"][room_id1].get("is_dm"), - ) - - def test_rooms_meta_heroes(self) -> None: - """ - Test that the `rooms` `heroes` are included in the response when the room - doesn't have a room name set. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - user3_id = self.register_user("user3", "pass") - _user3_tok = self.login(user3_id, "pass") - - room_id1 = self.helper.create_room_as( - user2_id, - tok=user2_tok, - extra_content={ - "name": "my super room", - }, - ) - self.helper.join(room_id1, user1_id, tok=user1_tok) - # User3 is invited - self.helper.invite(room_id1, src=user2_id, targ=user3_id, tok=user2_tok) - - room_id2 = self.helper.create_room_as( - user2_id, - tok=user2_tok, - extra_content={ - # No room name set so that `heroes` is populated - # - # "name": "my super room2", - }, - ) - self.helper.join(room_id2, user1_id, tok=user1_tok) - # User3 is invited - self.helper.invite(room_id2, src=user2_id, targ=user3_id, tok=user2_tok) - - # Make the Sliding Sync request - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [], - "timeline_limit": 0, - } - } - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # Room1 has a name so we shouldn't see any `heroes` which the client would use - # the calculate the room name themselves. 
- self.assertEqual( - response_body["rooms"][room_id1]["name"], - "my super room", - response_body["rooms"][room_id1], - ) - self.assertIsNone(response_body["rooms"][room_id1].get("heroes")) - self.assertEqual( - response_body["rooms"][room_id1]["joined_count"], - 2, - ) - self.assertEqual( - response_body["rooms"][room_id1]["invited_count"], - 1, - ) - - # Room2 doesn't have a name so we should see `heroes` populated - self.assertIsNone(response_body["rooms"][room_id2].get("name")) - self.assertCountEqual( - [ - hero["user_id"] - for hero in response_body["rooms"][room_id2].get("heroes", []) - ], - # Heroes shouldn't include the user themselves (we shouldn't see user1) - [user2_id, user3_id], - ) - self.assertEqual( - response_body["rooms"][room_id2]["joined_count"], - 2, - ) - self.assertEqual( - response_body["rooms"][room_id2]["invited_count"], - 1, - ) - - # We didn't request any state so we shouldn't see any `required_state` - self.assertIsNone(response_body["rooms"][room_id1].get("required_state")) - self.assertIsNone(response_body["rooms"][room_id2].get("required_state")) - - def test_rooms_meta_heroes_max(self) -> None: - """ - Test that the `rooms` `heroes` only includes the first 5 users (not including - yourself). 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - user3_id = self.register_user("user3", "pass") - user3_tok = self.login(user3_id, "pass") - user4_id = self.register_user("user4", "pass") - user4_tok = self.login(user4_id, "pass") - user5_id = self.register_user("user5", "pass") - user5_tok = self.login(user5_id, "pass") - user6_id = self.register_user("user6", "pass") - user6_tok = self.login(user6_id, "pass") - user7_id = self.register_user("user7", "pass") - user7_tok = self.login(user7_id, "pass") - - room_id1 = self.helper.create_room_as( - user2_id, - tok=user2_tok, - extra_content={ - # No room name set so that `heroes` is populated - # - # "name": "my super room", - }, - ) - self.helper.join(room_id1, user1_id, tok=user1_tok) - self.helper.join(room_id1, user3_id, tok=user3_tok) - self.helper.join(room_id1, user4_id, tok=user4_tok) - self.helper.join(room_id1, user5_id, tok=user5_tok) - self.helper.join(room_id1, user6_id, tok=user6_tok) - self.helper.join(room_id1, user7_id, tok=user7_tok) - - # Make the Sliding Sync request - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [], - "timeline_limit": 0, - } - } - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # Room2 doesn't have a name so we should see `heroes` populated - self.assertIsNone(response_body["rooms"][room_id1].get("name")) - self.assertCountEqual( - [ - hero["user_id"] - for hero in response_body["rooms"][room_id1].get("heroes", []) - ], - # Heroes should be the first 5 users in the room (excluding the user - # themselves, we shouldn't see `user1`) - [user2_id, user3_id, user4_id, user5_id, user6_id], - ) - self.assertEqual( - response_body["rooms"][room_id1]["joined_count"], - 7, - ) - self.assertEqual( - response_body["rooms"][room_id1]["invited_count"], - 0, - ) - - # We didn't request any state 
so we shouldn't see any `required_state` - self.assertIsNone(response_body["rooms"][room_id1].get("required_state")) - - def test_rooms_meta_heroes_when_banned(self) -> None: - """ - Test that the `rooms` `heroes` are included in the response when the room - doesn't have a room name set but doesn't leak information past their ban. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - user3_id = self.register_user("user3", "pass") - _user3_tok = self.login(user3_id, "pass") - user4_id = self.register_user("user4", "pass") - user4_tok = self.login(user4_id, "pass") - user5_id = self.register_user("user5", "pass") - _user5_tok = self.login(user5_id, "pass") - - room_id1 = self.helper.create_room_as( - user2_id, - tok=user2_tok, - extra_content={ - # No room name set so that `heroes` is populated - # - # "name": "my super room", - }, - ) - # User1 joins the room - self.helper.join(room_id1, user1_id, tok=user1_tok) - # User3 is invited - self.helper.invite(room_id1, src=user2_id, targ=user3_id, tok=user2_tok) - - # User1 is banned from the room - self.helper.ban(room_id1, src=user2_id, targ=user1_id, tok=user2_tok) - - # User4 joins the room after user1 is banned - self.helper.join(room_id1, user4_id, tok=user4_tok) - # User5 is invited after user1 is banned - self.helper.invite(room_id1, src=user2_id, targ=user5_id, tok=user2_tok) - - # Make the Sliding Sync request - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [], - "timeline_limit": 0, - } - } - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # Room2 doesn't have a name so we should see `heroes` populated - self.assertIsNone(response_body["rooms"][room_id1].get("name")) - self.assertCountEqual( - [ - hero["user_id"] - for hero in response_body["rooms"][room_id1].get("heroes", []) - ], - # Heroes shouldn't include the user 
themselves (we shouldn't see user1). We - # also shouldn't see user4 since they joined after user1 was banned. - # - # FIXME: The actual result should be `[user2_id, user3_id]` but we currently - # don't support this for rooms where the user has left/been banned. - [], - ) - - self.assertEqual( - response_body["rooms"][room_id1]["joined_count"], - # FIXME: The actual number should be "1" (user2) but we currently don't - # support this for rooms where the user has left/been banned. - 0, - ) - self.assertEqual( - response_body["rooms"][room_id1]["invited_count"], - # We shouldn't see user5 since they were invited after user1 was banned. - # - # FIXME: The actual number should be "1" (user3) but we currently don't - # support this for rooms where the user has left/been banned. - 0, - ) - - def test_rooms_bump_stamp(self) -> None: - """ - Test that `bump_stamp` is present and pointing to relevant events. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - room_id1 = self.helper.create_room_as( - user1_id, - tok=user1_tok, - ) - event_response1 = message_response = self.helper.send( - room_id1, "message in room1", tok=user1_tok - ) - event_pos1 = self.get_success( - self.store.get_position_for_event(event_response1["event_id"]) - ) - room_id2 = self.helper.create_room_as( - user1_id, - tok=user1_tok, - ) - send_response2 = self.helper.send(room_id2, "message in room2", tok=user1_tok) - event_pos2 = self.get_success( - self.store.get_position_for_event(send_response2["event_id"]) - ) - - # Send a reaction in room1 but it shouldn't affect the `bump_stamp` - # because reactions are not part of the `DEFAULT_BUMP_EVENT_TYPES` - self.helper.send_event( - room_id1, - type=EventTypes.Reaction, - content={ - "m.relates_to": { - "event_id": message_response["event_id"], - "key": "👍", - "rel_type": "m.annotation", - } - }, - tok=user1_tok, - ) - - # Make the Sliding Sync request - sync_body = { - "lists": { - "foo-list": { - 
"ranges": [[0, 1]], - "required_state": [], - "timeline_limit": 100, - } - } - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # Make sure it has the foo-list we requested - self.assertListEqual( - list(response_body["lists"].keys()), - ["foo-list"], - response_body["lists"].keys(), - ) - - # Make sure the list includes the rooms in the right order - self.assertListEqual( - list(response_body["lists"]["foo-list"]["ops"]), - [ - { - "op": "SYNC", - "range": [0, 1], - # room1 sorts before room2 because it has the latest event (the - # reaction) - "room_ids": [room_id1, room_id2], - } - ], - response_body["lists"]["foo-list"], - ) - - # The `bump_stamp` for room1 should point at the latest message (not the - # reaction since it's not one of the `DEFAULT_BUMP_EVENT_TYPES`) - self.assertEqual( - response_body["rooms"][room_id1]["bump_stamp"], - event_pos1.stream, - response_body["rooms"][room_id1], - ) - - # The `bump_stamp` for room2 should point at the latest message - self.assertEqual( - response_body["rooms"][room_id2]["bump_stamp"], - event_pos2.stream, - response_body["rooms"][room_id2], - ) - - def test_rooms_bump_stamp_backfill(self) -> None: - """ - Test that `bump_stamp` ignores backfilled events, i.e. events with a - negative stream ordering. 
- """ - - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - # Create a remote room - creator = "@user:other" - room_id = "!foo:other" - shared_kwargs = { - "room_id": room_id, - "room_version": "10", - } - - create_tuple = self.get_success( - create_event( - self.hs, - prev_event_ids=[], - type=EventTypes.Create, - state_key="", - sender=creator, - **shared_kwargs, - ) - ) - creator_tuple = self.get_success( - create_event( - self.hs, - prev_event_ids=[create_tuple[0].event_id], - auth_event_ids=[create_tuple[0].event_id], - type=EventTypes.Member, - state_key=creator, - content={"membership": Membership.JOIN}, - sender=creator, - **shared_kwargs, - ) - ) - # We add a message event as a valid "bump type" - msg_tuple = self.get_success( - create_event( - self.hs, - prev_event_ids=[creator_tuple[0].event_id], - auth_event_ids=[create_tuple[0].event_id], - type=EventTypes.Message, - content={"body": "foo", "msgtype": "m.text"}, - sender=creator, - **shared_kwargs, - ) - ) - invite_tuple = self.get_success( - create_event( - self.hs, - prev_event_ids=[msg_tuple[0].event_id], - auth_event_ids=[create_tuple[0].event_id, creator_tuple[0].event_id], - type=EventTypes.Member, - state_key=user1_id, - content={"membership": Membership.INVITE}, - sender=creator, - **shared_kwargs, - ) - ) - - remote_events_and_contexts = [ - create_tuple, - creator_tuple, - msg_tuple, - invite_tuple, - ] - - # Ensure the local HS knows the room version - self.get_success( - self.store.store_room(room_id, creator, False, RoomVersions.V10) - ) - - # Persist these events as backfilled events. 
- persistence = self.hs.get_storage_controllers().persistence - assert persistence is not None - - for event, context in remote_events_and_contexts: - self.get_success(persistence.persist_event(event, context, backfilled=True)) - - # Now we join the local user to the room - join_tuple = self.get_success( - create_event( - self.hs, - prev_event_ids=[invite_tuple[0].event_id], - auth_event_ids=[create_tuple[0].event_id, invite_tuple[0].event_id], - type=EventTypes.Member, - state_key=user1_id, - content={"membership": Membership.JOIN}, - sender=user1_id, - **shared_kwargs, - ) - ) - self.get_success(persistence.persist_event(*join_tuple)) - - # Doing an SS request should return a positive `bump_stamp`, even though - # the only event that matches the bump types has as negative stream - # ordering. - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [], - "timeline_limit": 5, - } - } - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - self.assertGreater(response_body["rooms"][room_id]["bump_stamp"], 0) diff --git a/tests/rest/client/sliding_sync/test_rooms_required_state.py b/tests/rest/client/sliding_sync/test_rooms_required_state.py deleted file mode 100644
index a13cad223f..0000000000 --- a/tests/rest/client/sliding_sync/test_rooms_required_state.py +++ /dev/null
@@ -1,707 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2024 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -import logging - -from parameterized import parameterized - -from twisted.test.proto_helpers import MemoryReactor - -import synapse.rest.admin -from synapse.api.constants import EventTypes, Membership -from synapse.handlers.sliding_sync import StateValues -from synapse.rest.client import login, room, sync -from synapse.server import HomeServer -from synapse.util import Clock - -from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase -from tests.test_utils.event_injection import mark_event_as_partial_state - -logger = logging.getLogger(__name__) - - -class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase): - """ - Test `rooms.required_state` in the Sliding Sync API. 
- """ - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - sync.register_servlets, - ] - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastores().main - self.storage_controllers = hs.get_storage_controllers() - - def test_rooms_no_required_state(self) -> None: - """ - Empty `rooms.required_state` should not return any state events in the room - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id1, user1_id, tok=user1_tok) - - # Make the Sliding Sync request - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - # Empty `required_state` - "required_state": [], - "timeline_limit": 0, - } - } - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # No `required_state` in response - self.assertIsNone( - response_body["rooms"][room_id1].get("required_state"), - response_body["rooms"][room_id1], - ) - - def test_rooms_required_state_initial_sync(self) -> None: - """ - Test `rooms.required_state` returns requested state events in the room during an - initial sync. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id1, user1_id, tok=user1_tok) - - # Make the Sliding Sync request - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [ - [EventTypes.Create, ""], - [EventTypes.RoomHistoryVisibility, ""], - # This one doesn't exist in the room - [EventTypes.Tombstone, ""], - ], - "timeline_limit": 0, - } - } - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - state_map = self.get_success( - self.storage_controllers.state.get_current_state(room_id1) - ) - - self._assertRequiredStateIncludes( - response_body["rooms"][room_id1]["required_state"], - { - state_map[(EventTypes.Create, "")], - state_map[(EventTypes.RoomHistoryVisibility, "")], - }, - exact=True, - ) - self.assertIsNone(response_body["rooms"][room_id1].get("invite_state")) - - def test_rooms_required_state_incremental_sync(self) -> None: - """ - Test `rooms.required_state` returns requested state events in the room during an - incremental sync. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id1, user1_id, tok=user1_tok) - - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [ - [EventTypes.Create, ""], - [EventTypes.RoomHistoryVisibility, ""], - # This one doesn't exist in the room - [EventTypes.Tombstone, ""], - ], - "timeline_limit": 1, - } - } - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - # Send a message so the room comes down sync. 
- self.helper.send(room_id1, "msg", tok=user1_tok) - - # Make the incremental Sliding Sync request - response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok) - - # We only return updates but only if we've sent the room down the - # connection before. - self.assertIsNone(response_body["rooms"][room_id1].get("required_state")) - self.assertIsNone(response_body["rooms"][room_id1].get("invite_state")) - - def test_rooms_incremental_sync_restart(self) -> None: - """ - Test that after a restart (and so the in memory caches are reset) that - we correctly return an `M_UNKNOWN_POS` - """ - - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id1, user1_id, tok=user1_tok) - - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [ - [EventTypes.Create, ""], - [EventTypes.RoomHistoryVisibility, ""], - # This one doesn't exist in the room - [EventTypes.Tombstone, ""], - ], - "timeline_limit": 1, - } - } - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - # Reset the in-memory cache - self.hs.get_sliding_sync_handler().connection_store._connections.clear() - - # Make the Sliding Sync request - channel = self.make_request( - method="POST", - path=self.sync_endpoint + f"?pos={from_token}", - content=sync_body, - access_token=user1_tok, - ) - self.assertEqual(channel.code, 400, channel.json_body) - self.assertEqual( - channel.json_body["errcode"], "M_UNKNOWN_POS", channel.json_body - ) - - def test_rooms_required_state_wildcard(self) -> None: - """ - Test `rooms.required_state` returns all state events when using wildcard `["*", "*"]`. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id1, user1_id, tok=user1_tok) - - self.helper.send_state( - room_id1, - event_type="org.matrix.foo_state", - state_key="", - body={"foo": "bar"}, - tok=user2_tok, - ) - self.helper.send_state( - room_id1, - event_type="org.matrix.foo_state", - state_key="namespaced", - body={"foo": "bar"}, - tok=user2_tok, - ) - - # Make the Sliding Sync request with wildcards for the `event_type` and `state_key` - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [ - [StateValues.WILDCARD, StateValues.WILDCARD], - ], - "timeline_limit": 0, - } - } - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - state_map = self.get_success( - self.storage_controllers.state.get_current_state(room_id1) - ) - - self._assertRequiredStateIncludes( - response_body["rooms"][room_id1]["required_state"], - # We should see all the state events in the room - state_map.values(), - exact=True, - ) - self.assertIsNone(response_body["rooms"][room_id1].get("invite_state")) - - def test_rooms_required_state_wildcard_event_type(self) -> None: - """ - Test `rooms.required_state` returns relevant state events when using wildcard in - the event_type `["*", "foobarbaz"]`. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id1, user1_id, tok=user1_tok) - - self.helper.send_state( - room_id1, - event_type="org.matrix.foo_state", - state_key="", - body={"foo": "bar"}, - tok=user2_tok, - ) - self.helper.send_state( - room_id1, - event_type="org.matrix.foo_state", - state_key=user2_id, - body={"foo": "bar"}, - tok=user2_tok, - ) - - # Make the Sliding Sync request with wildcards for the `event_type` - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [ - [StateValues.WILDCARD, user2_id], - ], - "timeline_limit": 0, - } - } - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - state_map = self.get_success( - self.storage_controllers.state.get_current_state(room_id1) - ) - - # We expect at-least any state event with the `user2_id` as the `state_key` - self._assertRequiredStateIncludes( - response_body["rooms"][room_id1]["required_state"], - { - state_map[(EventTypes.Member, user2_id)], - state_map[("org.matrix.foo_state", user2_id)], - }, - # Ideally, this would be exact but we're currently returning all state - # events when the `event_type` is a wildcard. - exact=False, - ) - self.assertIsNone(response_body["rooms"][room_id1].get("invite_state")) - - def test_rooms_required_state_wildcard_state_key(self) -> None: - """ - Test `rooms.required_state` returns relevant state events when using wildcard in - the state_key `["foobarbaz","*"]`. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id1, user1_id, tok=user1_tok) - - # Make the Sliding Sync request with wildcards for the `state_key` - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [ - [EventTypes.Member, StateValues.WILDCARD], - ], - "timeline_limit": 0, - } - } - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - state_map = self.get_success( - self.storage_controllers.state.get_current_state(room_id1) - ) - - self._assertRequiredStateIncludes( - response_body["rooms"][room_id1]["required_state"], - { - state_map[(EventTypes.Member, user1_id)], - state_map[(EventTypes.Member, user2_id)], - }, - exact=True, - ) - self.assertIsNone(response_body["rooms"][room_id1].get("invite_state")) - - def test_rooms_required_state_lazy_loading_room_members(self) -> None: - """ - Test `rooms.required_state` returns people relevant to the timeline when - lazy-loading room members, `["m.room.member","$LAZY"]`. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - user3_id = self.register_user("user3", "pass") - user3_tok = self.login(user3_id, "pass") - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id1, user1_id, tok=user1_tok) - self.helper.join(room_id1, user3_id, tok=user3_tok) - - self.helper.send(room_id1, "1", tok=user2_tok) - self.helper.send(room_id1, "2", tok=user3_tok) - self.helper.send(room_id1, "3", tok=user2_tok) - - # Make the Sliding Sync request with lazy loading for the room members - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [ - [EventTypes.Create, ""], - [EventTypes.Member, StateValues.LAZY], - ], - "timeline_limit": 3, - } - } - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - state_map = self.get_success( - self.storage_controllers.state.get_current_state(room_id1) - ) - - # Only user2 and user3 sent events in the 3 events we see in the `timeline` - self._assertRequiredStateIncludes( - response_body["rooms"][room_id1]["required_state"], - { - state_map[(EventTypes.Create, "")], - state_map[(EventTypes.Member, user2_id)], - state_map[(EventTypes.Member, user3_id)], - }, - exact=True, - ) - self.assertIsNone(response_body["rooms"][room_id1].get("invite_state")) - - def test_rooms_required_state_me(self) -> None: - """ - Test `rooms.required_state` correctly handles $ME. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id1, user1_id, tok=user1_tok) - - self.helper.send(room_id1, "1", tok=user2_tok) - - # Also send normal state events with state keys of the users, first - # change the power levels to allow this. - self.helper.send_state( - room_id1, - event_type=EventTypes.PowerLevels, - body={"users": {user1_id: 50, user2_id: 100}}, - tok=user2_tok, - ) - self.helper.send_state( - room_id1, - event_type="org.matrix.foo", - state_key=user1_id, - body={}, - tok=user1_tok, - ) - self.helper.send_state( - room_id1, - event_type="org.matrix.foo", - state_key=user2_id, - body={}, - tok=user2_tok, - ) - - # Make the Sliding Sync request with a request for '$ME'. - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [ - [EventTypes.Create, ""], - [EventTypes.Member, StateValues.ME], - ["org.matrix.foo", StateValues.ME], - ], - "timeline_limit": 3, - } - } - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - state_map = self.get_success( - self.storage_controllers.state.get_current_state(room_id1) - ) - - # Only user2 and user3 sent events in the 3 events we see in the `timeline` - self._assertRequiredStateIncludes( - response_body["rooms"][room_id1]["required_state"], - { - state_map[(EventTypes.Create, "")], - state_map[(EventTypes.Member, user1_id)], - state_map[("org.matrix.foo", user1_id)], - }, - exact=True, - ) - self.assertIsNone(response_body["rooms"][room_id1].get("invite_state")) - - @parameterized.expand([(Membership.LEAVE,), (Membership.BAN,)]) - def test_rooms_required_state_leave_ban(self, stop_membership: str) -> None: - """ - Test `rooms.required_state` should not return state past a leave/ban event. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - user3_id = self.register_user("user3", "pass") - user3_tok = self.login(user3_id, "pass") - - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [ - [EventTypes.Create, ""], - [EventTypes.Member, "*"], - ["org.matrix.foo_state", ""], - ], - "timeline_limit": 3, - } - } - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id1, user1_id, tok=user1_tok) - self.helper.join(room_id1, user3_id, tok=user3_tok) - - self.helper.send_state( - room_id1, - event_type="org.matrix.foo_state", - state_key="", - body={"foo": "bar"}, - tok=user2_tok, - ) - - if stop_membership == Membership.LEAVE: - # User 1 leaves - self.helper.leave(room_id1, user1_id, tok=user1_tok) - elif stop_membership == Membership.BAN: - # User 1 is banned - self.helper.ban(room_id1, src=user2_id, targ=user1_id, tok=user2_tok) - - state_map = self.get_success( - self.storage_controllers.state.get_current_state(room_id1) - ) - - # Change the state after user 1 leaves - self.helper.send_state( - room_id1, - event_type="org.matrix.foo_state", - state_key="", - body={"foo": "qux"}, - tok=user2_tok, - ) - self.helper.leave(room_id1, user3_id, tok=user3_tok) - - # Make the Sliding Sync request with lazy loading for the room members - response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok) - - # Only user2 and user3 sent events in the 3 events we see in the `timeline` - self._assertRequiredStateIncludes( - response_body["rooms"][room_id1]["required_state"], - { - state_map[(EventTypes.Create, "")], - state_map[(EventTypes.Member, user1_id)], - state_map[(EventTypes.Member, user2_id)], - state_map[(EventTypes.Member, user3_id)], - state_map[("org.matrix.foo_state", "")], - }, 
- exact=True, - ) - self.assertIsNone(response_body["rooms"][room_id1].get("invite_state")) - - def test_rooms_required_state_combine_superset(self) -> None: - """ - Test `rooms.required_state` is combined across lists and room subscriptions. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id1, user1_id, tok=user1_tok) - - self.helper.send_state( - room_id1, - event_type="org.matrix.foo_state", - state_key="", - body={"foo": "bar"}, - tok=user2_tok, - ) - self.helper.send_state( - room_id1, - event_type="org.matrix.bar_state", - state_key="", - body={"bar": "qux"}, - tok=user2_tok, - ) - - # Make the Sliding Sync request with wildcards for the `state_key` - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [ - [EventTypes.Create, ""], - [EventTypes.Member, user1_id], - ], - "timeline_limit": 0, - }, - "bar-list": { - "ranges": [[0, 1]], - "required_state": [ - [EventTypes.Member, StateValues.WILDCARD], - ["org.matrix.foo_state", ""], - ], - "timeline_limit": 0, - }, - }, - "room_subscriptions": { - room_id1: { - "required_state": [["org.matrix.bar_state", ""]], - "timeline_limit": 0, - } - }, - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - state_map = self.get_success( - self.storage_controllers.state.get_current_state(room_id1) - ) - - self._assertRequiredStateIncludes( - response_body["rooms"][room_id1]["required_state"], - { - state_map[(EventTypes.Create, "")], - state_map[(EventTypes.Member, user1_id)], - state_map[(EventTypes.Member, user2_id)], - state_map[("org.matrix.foo_state", "")], - state_map[("org.matrix.bar_state", "")], - }, - exact=True, - ) - self.assertIsNone(response_body["rooms"][room_id1].get("invite_state")) - - def 
test_rooms_required_state_partial_state(self) -> None: - """ - Test partially-stated room are excluded unless `rooms.required_state` is - lazy-loading room members. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok) - _join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok) - join_response2 = self.helper.join(room_id2, user1_id, tok=user1_tok) - - # Mark room2 as partial state - self.get_success( - mark_event_as_partial_state(self.hs, join_response2["event_id"], room_id2) - ) - - # Make the Sliding Sync request (NOT lazy-loading room members) - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [ - [EventTypes.Create, ""], - ], - "timeline_limit": 0, - }, - } - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # Make sure the list includes room1 but room2 is excluded because it's still - # partially-stated - self.assertListEqual( - list(response_body["lists"]["foo-list"]["ops"]), - [ - { - "op": "SYNC", - "range": [0, 1], - "room_ids": [room_id1], - } - ], - response_body["lists"]["foo-list"], - ) - - # Make the Sliding Sync request (with lazy-loading room members) - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [ - [EventTypes.Create, ""], - # Lazy-load room members - [EventTypes.Member, StateValues.LAZY], - ], - "timeline_limit": 0, - }, - } - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # The list should include both rooms now because we're lazy-loading room members - self.assertListEqual( - list(response_body["lists"]["foo-list"]["ops"]), - [ - { - "op": "SYNC", - "range": [0, 1], - "room_ids": [room_id2, room_id1], - } - ], - response_body["lists"]["foo-list"], - ) diff 
--git a/tests/rest/client/sliding_sync/test_rooms_timeline.py b/tests/rest/client/sliding_sync/test_rooms_timeline.py deleted file mode 100644
index 2e9586ca73..0000000000 --- a/tests/rest/client/sliding_sync/test_rooms_timeline.py +++ /dev/null
@@ -1,575 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2024 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -import logging -from typing import List, Optional - -from twisted.test.proto_helpers import MemoryReactor - -import synapse.rest.admin -from synapse.rest.client import login, room, sync -from synapse.server import HomeServer -from synapse.types import StreamToken, StrSequence -from synapse.util import Clock - -from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase - -logger = logging.getLogger(__name__) - - -class SlidingSyncRoomsTimelineTestCase(SlidingSyncBase): - """ - Test `rooms.timeline` in the Sliding Sync API. - """ - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - sync.register_servlets, - ] - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastores().main - self.storage_controllers = hs.get_storage_controllers() - - def _assertListEqual( - self, - actual_items: StrSequence, - expected_items: StrSequence, - message: Optional[str] = None, - ) -> None: - """ - Like `self.assertListEqual(...)` but with an actually understandable diff message. 
- """ - - if actual_items == expected_items: - return - - expected_lines: List[str] = [] - for expected_item in expected_items: - is_expected_in_actual = expected_item in actual_items - expected_lines.append( - "{} {}".format(" " if is_expected_in_actual else "?", expected_item) - ) - - actual_lines: List[str] = [] - for actual_item in actual_items: - is_actual_in_expected = actual_item in expected_items - actual_lines.append( - "{} {}".format("+" if is_actual_in_expected else " ", actual_item) - ) - - newline = "\n" - expected_string = f"Expected items to be in actual ('?' = missing expected items):\n [\n{newline.join(expected_lines)}\n ]" - actual_string = f"Actual ('+' = found expected items):\n [\n{newline.join(actual_lines)}\n ]" - first_message = "Items must" - diff_message = f"{first_message}\n{expected_string}\n{actual_string}" - - self.fail(f"{diff_message}\n{message}") - - def _assertTimelineEqual( - self, - *, - room_id: str, - actual_event_ids: List[str], - expected_event_ids: List[str], - message: Optional[str] = None, - ) -> None: - """ - Like `self.assertListEqual(...)` for event IDs in a room but will give a nicer - output with context for what each event_id is (type, stream_ordering, content, - etc). 
- """ - if actual_event_ids == expected_event_ids: - return - - event_id_set = set(actual_event_ids + expected_event_ids) - events = self.get_success(self.store.get_events(event_id_set)) - - def event_id_to_string(event_id: str) -> str: - event = events.get(event_id) - if event: - state_key = event.get_state_key() - state_key_piece = f", {state_key}" if state_key is not None else "" - return ( - f"({event.internal_metadata.stream_ordering: >2}, {event.internal_metadata.instance_name}) " - + f"{event.event_id} ({event.type}{state_key_piece}) {event.content.get('membership', '')}{event.content.get('body', '')}" - ) - - return f"{event_id} <event not found in room_id={room_id}>" - - self._assertListEqual( - actual_items=[ - event_id_to_string(event_id) for event_id in actual_event_ids - ], - expected_items=[ - event_id_to_string(event_id) for event_id in expected_event_ids - ], - message=message, - ) - - def test_rooms_limited_initial_sync(self) -> None: - """ - Test that we mark `rooms` as `limited=True` when we saturate the `timeline_limit` - on initial sync. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.send(room_id1, "activity1", tok=user2_tok) - self.helper.send(room_id1, "activity2", tok=user2_tok) - event_response3 = self.helper.send(room_id1, "activity3", tok=user2_tok) - event_pos3 = self.get_success( - self.store.get_position_for_event(event_response3["event_id"]) - ) - event_response4 = self.helper.send(room_id1, "activity4", tok=user2_tok) - event_pos4 = self.get_success( - self.store.get_position_for_event(event_response4["event_id"]) - ) - event_response5 = self.helper.send(room_id1, "activity5", tok=user2_tok) - user1_join_response = self.helper.join(room_id1, user1_id, tok=user1_tok) - - # Make the Sliding Sync request - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [], - "timeline_limit": 3, - } - } - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # We expect to saturate the `timeline_limit` (there are more than 3 messages in the room) - self.assertEqual( - response_body["rooms"][room_id1]["limited"], - True, - response_body["rooms"][room_id1], - ) - # Check to make sure the latest events are returned - self._assertTimelineEqual( - room_id=room_id1, - actual_event_ids=[ - event["event_id"] - for event in response_body["rooms"][room_id1]["timeline"] - ], - expected_event_ids=[ - event_response4["event_id"], - event_response5["event_id"], - user1_join_response["event_id"], - ], - message=str(response_body["rooms"][room_id1]["timeline"]), - ) - - # Check to make sure the `prev_batch` points at the right place - prev_batch_token = self.get_success( - StreamToken.from_string( - self.store, response_body["rooms"][room_id1]["prev_batch"] - ) - ) - prev_batch_room_stream_token_serialized = self.get_success( - 
prev_batch_token.room_key.to_string(self.store) - ) - # If we use the `prev_batch` token to look backwards, we should see `event3` - # next so make sure the token encompasses it - self.assertEqual( - event_pos3.persisted_after(prev_batch_token.room_key), - False, - f"`prev_batch` token {prev_batch_room_stream_token_serialized} should be >= event_pos3={self.get_success(event_pos3.to_room_stream_token().to_string(self.store))}", - ) - # If we use the `prev_batch` token to look backwards, we shouldn't see `event4` - # anymore since it was just returned in this response. - self.assertEqual( - event_pos4.persisted_after(prev_batch_token.room_key), - True, - f"`prev_batch` token {prev_batch_room_stream_token_serialized} should be < event_pos4={self.get_success(event_pos4.to_room_stream_token().to_string(self.store))}", - ) - - # With no `from_token` (initial sync), it's all historical since there is no - # "live" range - self.assertEqual( - response_body["rooms"][room_id1]["num_live"], - 0, - response_body["rooms"][room_id1], - ) - - def test_rooms_not_limited_initial_sync(self) -> None: - """ - Test that we mark `rooms` as `limited=False` when there are no more events to - paginate to. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.send(room_id1, "activity1", tok=user2_tok) - self.helper.send(room_id1, "activity2", tok=user2_tok) - self.helper.send(room_id1, "activity3", tok=user2_tok) - self.helper.join(room_id1, user1_id, tok=user1_tok) - - # Make the Sliding Sync request - timeline_limit = 100 - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [], - "timeline_limit": timeline_limit, - } - } - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # The timeline should be `limited=False` because we have all of the events (no - # more to paginate to) - self.assertEqual( - response_body["rooms"][room_id1]["limited"], - False, - response_body["rooms"][room_id1], - ) - expected_number_of_events = 9 - # We're just looking to make sure we got all of the events before hitting the `timeline_limit` - self.assertEqual( - len(response_body["rooms"][room_id1]["timeline"]), - expected_number_of_events, - response_body["rooms"][room_id1]["timeline"], - ) - self.assertLessEqual(expected_number_of_events, timeline_limit) - - # With no `from_token` (initial sync), it's all historical since there is no - # "live" token range. - self.assertEqual( - response_body["rooms"][room_id1]["num_live"], - 0, - response_body["rooms"][room_id1], - ) - - def test_rooms_incremental_sync(self) -> None: - """ - Test `rooms` data during an incremental sync after an initial sync. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id1, user1_id, tok=user1_tok) - self.helper.send(room_id1, "activity before initial sync1", tok=user2_tok) - - # Make an initial Sliding Sync request to grab a token. This is also a sanity - # check that we can go from initial to incremental sync. - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [], - "timeline_limit": 3, - } - } - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - # Send some events but don't send enough to saturate the `timeline_limit`. - # We want to later test that we only get the new events since the `next_pos` - event_response2 = self.helper.send(room_id1, "activity after2", tok=user2_tok) - event_response3 = self.helper.send(room_id1, "activity after3", tok=user2_tok) - - # Make an incremental Sliding Sync request (what we're trying to test) - response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok) - - # We only expect to see the new events since the last sync which isn't enough to - # fill up the `timeline_limit`. - self.assertEqual( - response_body["rooms"][room_id1]["limited"], - False, - f'Our `timeline_limit` was {sync_body["lists"]["foo-list"]["timeline_limit"]} ' - + f'and {len(response_body["rooms"][room_id1]["timeline"])} events were returned in the timeline. 
' - + str(response_body["rooms"][room_id1]), - ) - # Check to make sure the latest events are returned - self._assertTimelineEqual( - room_id=room_id1, - actual_event_ids=[ - event["event_id"] - for event in response_body["rooms"][room_id1]["timeline"] - ], - expected_event_ids=[ - event_response2["event_id"], - event_response3["event_id"], - ], - message=str(response_body["rooms"][room_id1]["timeline"]), - ) - - # All events are "live" - self.assertEqual( - response_body["rooms"][room_id1]["num_live"], - 2, - response_body["rooms"][room_id1], - ) - - def test_rooms_newly_joined_incremental_sync(self) -> None: - """ - Test that when we make an incremental sync with a `newly_joined` `rooms`, we are - able to see some historical events before the `from_token`. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.send(room_id1, "activity before token1", tok=user2_tok) - event_response2 = self.helper.send( - room_id1, "activity before token2", tok=user2_tok - ) - - # The `timeline_limit` is set to 4 so we can at least see one historical event - # before the `from_token`. We should see historical events because this is a - # `newly_joined` room. - timeline_limit = 4 - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [], - "timeline_limit": timeline_limit, - } - } - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - # Join the room after the `from_token` which will make us consider this room as - # `newly_joined`. - user1_join_response = self.helper.join(room_id1, user1_id, tok=user1_tok) - - # Send some events but don't send enough to saturate the `timeline_limit`. 
- # We want to later test that we only get the new events since the `next_pos` - event_response3 = self.helper.send( - room_id1, "activity after token3", tok=user2_tok - ) - event_response4 = self.helper.send( - room_id1, "activity after token4", tok=user2_tok - ) - - # Make an incremental Sliding Sync request (what we're trying to test) - response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok) - - # We should see the new events and the rest should be filled with historical - # events which will make us `limited=True` since there are more to paginate to. - self.assertEqual( - response_body["rooms"][room_id1]["limited"], - True, - f"Our `timeline_limit` was {timeline_limit} " - + f'and {len(response_body["rooms"][room_id1]["timeline"])} events were returned in the timeline. ' - + str(response_body["rooms"][room_id1]), - ) - # Check to make sure that the "live" and historical events are returned - self._assertTimelineEqual( - room_id=room_id1, - actual_event_ids=[ - event["event_id"] - for event in response_body["rooms"][room_id1]["timeline"] - ], - expected_event_ids=[ - event_response2["event_id"], - user1_join_response["event_id"], - event_response3["event_id"], - event_response4["event_id"], - ], - message=str(response_body["rooms"][room_id1]["timeline"]), - ) - - # Only events after the `from_token` are "live" (join, event3, event4) - self.assertEqual( - response_body["rooms"][room_id1]["num_live"], - 3, - response_body["rooms"][room_id1], - ) - - def test_rooms_ban_initial_sync(self) -> None: - """ - Test that `rooms` we are banned from in an intial sync only allows us to see - timeline events up to the ban event. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.send(room_id1, "activity before1", tok=user2_tok) - self.helper.send(room_id1, "activity before2", tok=user2_tok) - self.helper.join(room_id1, user1_id, tok=user1_tok) - - event_response3 = self.helper.send(room_id1, "activity after3", tok=user2_tok) - event_response4 = self.helper.send(room_id1, "activity after4", tok=user2_tok) - user1_ban_response = self.helper.ban( - room_id1, src=user2_id, targ=user1_id, tok=user2_tok - ) - - self.helper.send(room_id1, "activity after5", tok=user2_tok) - self.helper.send(room_id1, "activity after6", tok=user2_tok) - - # Make the Sliding Sync request - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [], - "timeline_limit": 3, - } - } - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # We should see events before the ban but not after - self._assertTimelineEqual( - room_id=room_id1, - actual_event_ids=[ - event["event_id"] - for event in response_body["rooms"][room_id1]["timeline"] - ], - expected_event_ids=[ - event_response3["event_id"], - event_response4["event_id"], - user1_ban_response["event_id"], - ], - message=str(response_body["rooms"][room_id1]["timeline"]), - ) - # No "live" events in an initial sync (no `from_token` to define the "live" - # range) - self.assertEqual( - response_body["rooms"][room_id1]["num_live"], - 0, - response_body["rooms"][room_id1], - ) - # There are more events to paginate to - self.assertEqual( - response_body["rooms"][room_id1]["limited"], - True, - response_body["rooms"][room_id1], - ) - - def test_rooms_ban_incremental_sync1(self) -> None: - """ - Test that `rooms` we are banned from during the next incremental sync only - allows us to see timeline events up to the ban 
event. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.send(room_id1, "activity before1", tok=user2_tok) - self.helper.send(room_id1, "activity before2", tok=user2_tok) - self.helper.join(room_id1, user1_id, tok=user1_tok) - - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [], - "timeline_limit": 4, - } - } - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - event_response3 = self.helper.send(room_id1, "activity after3", tok=user2_tok) - event_response4 = self.helper.send(room_id1, "activity after4", tok=user2_tok) - # The ban is within the token range (between the `from_token` and the sliding - # sync request) - user1_ban_response = self.helper.ban( - room_id1, src=user2_id, targ=user1_id, tok=user2_tok - ) - - self.helper.send(room_id1, "activity after5", tok=user2_tok) - self.helper.send(room_id1, "activity after6", tok=user2_tok) - - # Make the incremental Sliding Sync request - response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok) - - # We should see events before the ban but not after - self._assertTimelineEqual( - room_id=room_id1, - actual_event_ids=[ - event["event_id"] - for event in response_body["rooms"][room_id1]["timeline"] - ], - expected_event_ids=[ - event_response3["event_id"], - event_response4["event_id"], - user1_ban_response["event_id"], - ], - message=str(response_body["rooms"][room_id1]["timeline"]), - ) - # All live events in the incremental sync - self.assertEqual( - response_body["rooms"][room_id1]["num_live"], - 3, - response_body["rooms"][room_id1], - ) - # There aren't anymore events to paginate to in this range - self.assertEqual( - response_body["rooms"][room_id1]["limited"], - False, - response_body["rooms"][room_id1], - ) - - def 
test_rooms_ban_incremental_sync2(self) -> None: - """ - Test that `rooms` we are banned from before the incremental sync don't return - any events in the timeline. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.send(room_id1, "activity before1", tok=user2_tok) - self.helper.join(room_id1, user1_id, tok=user1_tok) - - self.helper.send(room_id1, "activity after2", tok=user2_tok) - # The ban is before we get our `from_token` - self.helper.ban(room_id1, src=user2_id, targ=user1_id, tok=user2_tok) - - self.helper.send(room_id1, "activity after3", tok=user2_tok) - - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [], - "timeline_limit": 4, - } - } - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - self.helper.send(room_id1, "activity after4", tok=user2_tok) - - # Make the incremental Sliding Sync request - response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok) - - # Nothing to see for this banned user in the room in the token range - self.assertIsNone(response_body["rooms"].get(room_id1)) diff --git a/tests/rest/client/sliding_sync/test_sliding_sync.py b/tests/rest/client/sliding_sync/test_sliding_sync.py deleted file mode 100644
index cb7638c5ba..0000000000 --- a/tests/rest/client/sliding_sync/test_sliding_sync.py +++ /dev/null
@@ -1,974 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2024 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -import logging -from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple - -from typing_extensions import assert_never - -from twisted.test.proto_helpers import MemoryReactor - -import synapse.rest.admin -from synapse.api.constants import ( - AccountDataTypes, - EventContentFields, - EventTypes, - RoomTypes, -) -from synapse.events import EventBase -from synapse.rest.client import devices, login, receipts, room, sync -from synapse.server import HomeServer -from synapse.types import ( - JsonDict, - RoomStreamToken, - SlidingSyncStreamToken, - StreamKeyType, - StreamToken, -) -from synapse.util import Clock -from synapse.util.stringutils import random_string - -from tests import unittest -from tests.server import TimedOutException - -logger = logging.getLogger(__name__) - - -class SlidingSyncBase(unittest.HomeserverTestCase): - """Base class for sliding sync test cases""" - - sync_endpoint = "/_matrix/client/unstable/org.matrix.simplified_msc3575/sync" - - def default_config(self) -> JsonDict: - config = super().default_config() - # Enable sliding sync - config["experimental_features"] = {"msc3575_enabled": True} - return config - - def do_sync( - self, sync_body: JsonDict, *, since: Optional[str] = None, tok: str - ) -> Tuple[JsonDict, str]: - """Do a sliding sync request with given body. - - Asserts the request was successful. 
- - Attributes: - sync_body: The full request body to use - since: Optional since token - tok: Access token to use - - Returns: - A tuple of the response body and the `pos` field. - """ - - sync_path = self.sync_endpoint - if since: - sync_path += f"?pos={since}" - - channel = self.make_request( - method="POST", - path=sync_path, - content=sync_body, - access_token=tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - - return channel.json_body, channel.json_body["pos"] - - def _assertRequiredStateIncludes( - self, - actual_required_state: Any, - expected_state_events: Iterable[EventBase], - exact: bool = False, - ) -> None: - """ - Wrapper around `assertIncludes` to give slightly better looking diff error - messages that include some context "$event_id (type, state_key)". - - Args: - actual_required_state: The "required_state" of a room from a Sliding Sync - request response. - expected_state_events: The expected state events to be included in the - `actual_required_state`. - exact: Whether the actual state should be exactly equal to the expected - state (no extras). - """ - - assert isinstance(actual_required_state, list) - for event in actual_required_state: - assert isinstance(event, dict) - - self.assertIncludes( - { - f'{event["event_id"]} ("{event["type"]}", "{event["state_key"]}")' - for event in actual_required_state - }, - { - f'{event.event_id} ("{event.type}", "{event.state_key}")' - for event in expected_state_events - }, - exact=exact, - # Message to help understand the diff in context - message=str(actual_required_state), - ) - - def _bump_notifier_wait_for_events( - self, - user_id: str, - wake_stream_key: Literal[ - StreamKeyType.ACCOUNT_DATA, - StreamKeyType.PRESENCE, - ], - ) -> None: - """ - Wake-up a `notifier.wait_for_events(user_id)` call without affecting the Sliding - Sync results. - - Args: - user_id: The user ID to wake up the notifier for - wake_stream_key: The stream key to wake up. 
This will create an actual new - entity in that stream so it's best to choose one that won't affect the - Sliding Sync results you're testing for. In other words, if your testing - account data, choose `StreamKeyType.PRESENCE` instead. We support two - possible stream keys because you're probably testing one or the other so - one is always a "safe" option. - """ - # We're expecting some new activity from this point onwards - from_token = self.hs.get_event_sources().get_current_token() - - triggered_notifier_wait_for_events = False - - async def _on_new_acivity( - before_token: StreamToken, after_token: StreamToken - ) -> bool: - nonlocal triggered_notifier_wait_for_events - triggered_notifier_wait_for_events = True - return True - - notifier = self.hs.get_notifier() - - # Listen for some new activity for the user. We're just trying to confirm that - # our bump below actually does what we think it does (triggers new activity for - # the user). - result_awaitable = notifier.wait_for_events( - user_id, - 1000, - _on_new_acivity, - from_token=from_token, - ) - - # Update the account data or presence so that `notifier.wait_for_events(...)` - # wakes up. We chose these two options because they're least likely to show up - # in the Sliding Sync response so it won't affect whether we have results. 
- if wake_stream_key == StreamKeyType.ACCOUNT_DATA: - self.get_success( - self.hs.get_account_data_handler().add_account_data_for_user( - user_id, - "org.matrix.foobarbaz", - {"foo": "bar"}, - ) - ) - elif wake_stream_key == StreamKeyType.PRESENCE: - sending_user_id = self.register_user( - "user_bump_notifier_wait_for_events_" + random_string(10), "pass" - ) - sending_user_tok = self.login(sending_user_id, "pass") - test_msg = {"foo": "bar"} - chan = self.make_request( - "PUT", - "/_matrix/client/r0/sendToDevice/m.test/1234", - content={"messages": {user_id: {"d1": test_msg}}}, - access_token=sending_user_tok, - ) - self.assertEqual(chan.code, 200, chan.result) - else: - assert_never(wake_stream_key) - - # Wait for our notifier result - self.get_success(result_awaitable) - - if not triggered_notifier_wait_for_events: - raise AssertionError( - "Expected `notifier.wait_for_events(...)` to be triggered" - ) - - -class SlidingSyncTestCase(SlidingSyncBase): - """ - Tests regarding MSC3575 Sliding Sync `/sync` endpoint. - - Please put tests in more specific test files if applicable. This test class is meant - for generic behavior of the endpoint. - """ - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - sync.register_servlets, - devices.register_servlets, - receipts.register_servlets, - ] - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastores().main - self.event_sources = hs.get_event_sources() - self.storage_controllers = hs.get_storage_controllers() - self.account_data_handler = hs.get_account_data_handler() - - def _add_new_dm_to_global_account_data( - self, source_user_id: str, target_user_id: str, target_room_id: str - ) -> None: - """ - Helper to handle inserting a new DM for the source user into global account data - (handles all of the list merging). 
- - Args: - source_user_id: The user ID of the DM mapping we're going to update - target_user_id: User ID of the person the DM is with - target_room_id: Room ID of the DM - """ - - # Get the current DM map - existing_dm_map = self.get_success( - self.store.get_global_account_data_by_type_for_user( - source_user_id, AccountDataTypes.DIRECT - ) - ) - # Scrutinize the account data since it has no concrete type. We're just copying - # everything into a known type. It should be a mapping from user ID to a list of - # room IDs. Ignore anything else. - new_dm_map: Dict[str, List[str]] = {} - if isinstance(existing_dm_map, dict): - for user_id, room_ids in existing_dm_map.items(): - if isinstance(user_id, str) and isinstance(room_ids, list): - for room_id in room_ids: - if isinstance(room_id, str): - new_dm_map[user_id] = new_dm_map.get(user_id, []) + [ - room_id - ] - - # Add the new DM to the map - new_dm_map[target_user_id] = new_dm_map.get(target_user_id, []) + [ - target_room_id - ] - # Save the DM map to global account data - self.get_success( - self.store.add_account_data_for_user( - source_user_id, - AccountDataTypes.DIRECT, - new_dm_map, - ) - ) - - def _create_dm_room( - self, - inviter_user_id: str, - inviter_tok: str, - invitee_user_id: str, - invitee_tok: str, - should_join_room: bool = True, - ) -> str: - """ - Helper to create a DM room as the "inviter" and invite the "invitee" user to the - room. The "invitee" user also will join the room. The `m.direct` account data - will be set for both users. 
- """ - - # Create a room and send an invite the other user - room_id = self.helper.create_room_as( - inviter_user_id, - is_public=False, - tok=inviter_tok, - ) - self.helper.invite( - room_id, - src=inviter_user_id, - targ=invitee_user_id, - tok=inviter_tok, - extra_data={"is_direct": True}, - ) - if should_join_room: - # Person that was invited joins the room - self.helper.join(room_id, invitee_user_id, tok=invitee_tok) - - # Mimic the client setting the room as a direct message in the global account - # data for both users. - self._add_new_dm_to_global_account_data( - invitee_user_id, inviter_user_id, room_id - ) - self._add_new_dm_to_global_account_data( - inviter_user_id, invitee_user_id, room_id - ) - - return room_id - - def test_sync_list(self) -> None: - """ - Test that room IDs show up in the Sliding Sync `lists` - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - room_id = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True) - - # Make the Sliding Sync request - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 99]], - "required_state": [], - "timeline_limit": 1, - } - } - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # Make sure it has the foo-list we requested - self.assertListEqual( - list(response_body["lists"].keys()), - ["foo-list"], - response_body["lists"].keys(), - ) - - # Make sure the list includes the room we are joined to - self.assertListEqual( - list(response_body["lists"]["foo-list"]["ops"]), - [ - { - "op": "SYNC", - "range": [0, 99], - "room_ids": [room_id], - } - ], - response_body["lists"]["foo-list"], - ) - - def test_wait_for_sync_token(self) -> None: - """ - Test that worker will wait until it catches up to the given token - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - # Create a future token that will cause us to wait. 
Since we never send a new - # event to reach that future stream_ordering, the worker will wait until the - # full timeout. - stream_id_gen = self.store.get_events_stream_id_generator() - stream_id = self.get_success(stream_id_gen.get_next().__aenter__()) - current_token = self.event_sources.get_current_token() - future_position_token = current_token.copy_and_replace( - StreamKeyType.ROOM, - RoomStreamToken(stream=stream_id), - ) - - future_position_token_serialized = self.get_success( - SlidingSyncStreamToken(future_position_token, 0).to_string(self.store) - ) - - # Make the Sliding Sync request - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 99]], - "required_state": [], - "timeline_limit": 1, - } - } - } - channel = self.make_request( - "POST", - self.sync_endpoint + f"?pos={future_position_token_serialized}", - content=sync_body, - access_token=user1_tok, - await_result=False, - ) - # Block for 10 seconds to make `notifier.wait_for_stream_token(from_token)` - # timeout - with self.assertRaises(TimedOutException): - channel.await_result(timeout_ms=9900) - channel.await_result(timeout_ms=200) - self.assertEqual(channel.code, 200, channel.json_body) - - # We expect the next `pos` in the result to be the same as what we requested - # with because we weren't able to find anything new yet. - self.assertEqual(channel.json_body["pos"], future_position_token_serialized) - - def test_wait_for_new_data(self) -> None: - """ - Test to make sure that the Sliding Sync request waits for new data to arrive. 
- - (Only applies to incremental syncs with a `timeout` specified) - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id, user1_id, tok=user1_tok) - - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 0]], - "required_state": [], - "timeline_limit": 1, - } - } - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - # Make the Sliding Sync request - channel = self.make_request( - "POST", - self.sync_endpoint + f"?timeout=10000&pos={from_token}", - content=sync_body, - access_token=user1_tok, - await_result=False, - ) - # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)` - with self.assertRaises(TimedOutException): - channel.await_result(timeout_ms=5000) - # Bump the room with new events to trigger new results - event_response1 = self.helper.send( - room_id, "new activity in room", tok=user1_tok - ) - # Should respond before the 10 second timeout - channel.await_result(timeout_ms=3000) - self.assertEqual(channel.code, 200, channel.json_body) - - # Check to make sure the new event is returned - self.assertEqual( - [ - event["event_id"] - for event in channel.json_body["rooms"][room_id]["timeline"] - ], - [ - event_response1["event_id"], - ], - channel.json_body["rooms"][room_id]["timeline"], - ) - - def test_wait_for_new_data_timeout(self) -> None: - """ - Test to make sure that the Sliding Sync request waits for new data to arrive but - no data ever arrives so we timeout. We're also making sure that the default data - doesn't trigger a false-positive for new data. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id, user1_id, tok=user1_tok) - - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 0]], - "required_state": [], - "timeline_limit": 1, - } - } - } - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - # Make the Sliding Sync request - channel = self.make_request( - "POST", - self.sync_endpoint + f"?timeout=10000&pos={from_token}", - content=sync_body, - access_token=user1_tok, - await_result=False, - ) - # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)` - with self.assertRaises(TimedOutException): - channel.await_result(timeout_ms=5000) - # Wake-up `notifier.wait_for_events(...)` that will cause us test - # `SlidingSyncResult.__bool__` for new results. - self._bump_notifier_wait_for_events( - user1_id, wake_stream_key=StreamKeyType.ACCOUNT_DATA - ) - # Block for a little bit more to ensure we don't see any new results. - with self.assertRaises(TimedOutException): - channel.await_result(timeout_ms=4000) - # Wait for the sync to complete (wait for the rest of the 10 second timeout, - # 5000 + 4000 + 1200 > 10000) - channel.await_result(timeout_ms=1200) - self.assertEqual(channel.code, 200, channel.json_body) - - # There should be no room sent down. 
- self.assertFalse(channel.json_body["rooms"]) - - def test_filter_list(self) -> None: - """ - Test that filters apply to `lists` - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - # Create a DM room - joined_dm_room_id = self._create_dm_room( - inviter_user_id=user1_id, - inviter_tok=user1_tok, - invitee_user_id=user2_id, - invitee_tok=user2_tok, - should_join_room=True, - ) - invited_dm_room_id = self._create_dm_room( - inviter_user_id=user1_id, - inviter_tok=user1_tok, - invitee_user_id=user2_id, - invitee_tok=user2_tok, - should_join_room=False, - ) - - # Create a normal room - room_id = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id, user1_id, tok=user1_tok) - - # Create a room that user1 is invited to - invite_room_id = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok) - - # Make the Sliding Sync request - sync_body = { - "lists": { - # Absense of filters does not imply "False" values - "all": { - "ranges": [[0, 99]], - "required_state": [], - "timeline_limit": 1, - "filters": {}, - }, - # Test single truthy filter - "dms": { - "ranges": [[0, 99]], - "required_state": [], - "timeline_limit": 1, - "filters": {"is_dm": True}, - }, - # Test single falsy filter - "non-dms": { - "ranges": [[0, 99]], - "required_state": [], - "timeline_limit": 1, - "filters": {"is_dm": False}, - }, - # Test how multiple filters should stack (AND'd together) - "room-invites": { - "ranges": [[0, 99]], - "required_state": [], - "timeline_limit": 1, - "filters": {"is_dm": False, "is_invite": True}, - }, - } - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # Make sure it has the foo-list we requested - self.assertListEqual( - list(response_body["lists"].keys()), - ["all", "dms", "non-dms", "room-invites"], - 
response_body["lists"].keys(), - ) - - # Make sure the lists have the correct rooms - self.assertListEqual( - list(response_body["lists"]["all"]["ops"]), - [ - { - "op": "SYNC", - "range": [0, 99], - "room_ids": [ - invite_room_id, - room_id, - invited_dm_room_id, - joined_dm_room_id, - ], - } - ], - list(response_body["lists"]["all"]), - ) - self.assertListEqual( - list(response_body["lists"]["dms"]["ops"]), - [ - { - "op": "SYNC", - "range": [0, 99], - "room_ids": [invited_dm_room_id, joined_dm_room_id], - } - ], - list(response_body["lists"]["dms"]), - ) - self.assertListEqual( - list(response_body["lists"]["non-dms"]["ops"]), - [ - { - "op": "SYNC", - "range": [0, 99], - "room_ids": [invite_room_id, room_id], - } - ], - list(response_body["lists"]["non-dms"]), - ) - self.assertListEqual( - list(response_body["lists"]["room-invites"]["ops"]), - [ - { - "op": "SYNC", - "range": [0, 99], - "room_ids": [invite_room_id], - } - ], - list(response_body["lists"]["room-invites"]), - ) - - # Ensure DM's are correctly marked - self.assertDictEqual( - { - room_id: room.get("is_dm") - for room_id, room in response_body["rooms"].items() - }, - { - invite_room_id: None, - room_id: None, - invited_dm_room_id: True, - joined_dm_room_id: True, - }, - ) - - def test_filter_regardless_of_membership_server_left_room(self) -> None: - """ - Test that filters apply to rooms regardless of membership. We're also - compounding the problem by having all of the local users leave the room causing - our server to leave the room. - - We want to make sure that if someone is filtering rooms, and leaves, you still - get that final update down sync that you left. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - # Create a normal room - room_id = self.helper.create_room_as(user1_id, tok=user2_tok) - self.helper.join(room_id, user1_id, tok=user1_tok) - - # Create an encrypted space room - space_room_id = self.helper.create_room_as( - user2_id, - tok=user2_tok, - extra_content={ - "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} - }, - ) - self.helper.send_state( - space_room_id, - EventTypes.RoomEncryption, - {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"}, - tok=user2_tok, - ) - self.helper.join(space_room_id, user1_id, tok=user1_tok) - - # Make an initial Sliding Sync request - channel = self.make_request( - "POST", - self.sync_endpoint, - { - "lists": { - "all-list": { - "ranges": [[0, 99]], - "required_state": [], - "timeline_limit": 0, - "filters": {}, - }, - "foo-list": { - "ranges": [[0, 99]], - "required_state": [], - "timeline_limit": 1, - "filters": { - "is_encrypted": True, - "room_types": [RoomTypes.SPACE], - }, - }, - } - }, - access_token=user1_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - from_token = channel.json_body["pos"] - - # Make sure the response has the lists we requested - self.assertListEqual( - list(channel.json_body["lists"].keys()), - ["all-list", "foo-list"], - channel.json_body["lists"].keys(), - ) - - # Make sure the lists have the correct rooms - self.assertListEqual( - list(channel.json_body["lists"]["all-list"]["ops"]), - [ - { - "op": "SYNC", - "range": [0, 99], - "room_ids": [space_room_id, room_id], - } - ], - ) - self.assertListEqual( - list(channel.json_body["lists"]["foo-list"]["ops"]), - [ - { - "op": "SYNC", - "range": [0, 99], - "room_ids": [space_room_id], - } - ], - ) - - # Everyone leaves the encrypted space room - self.helper.leave(space_room_id, user1_id, tok=user1_tok) - 
self.helper.leave(space_room_id, user2_id, tok=user2_tok) - - # Make an incremental Sliding Sync request - channel = self.make_request( - "POST", - self.sync_endpoint + f"?pos={from_token}", - { - "lists": { - "all-list": { - "ranges": [[0, 99]], - "required_state": [], - "timeline_limit": 0, - "filters": {}, - }, - "foo-list": { - "ranges": [[0, 99]], - "required_state": [], - "timeline_limit": 1, - "filters": { - "is_encrypted": True, - "room_types": [RoomTypes.SPACE], - }, - }, - } - }, - access_token=user1_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Make sure the lists have the correct rooms even though we `newly_left` - self.assertListEqual( - list(channel.json_body["lists"]["all-list"]["ops"]), - [ - { - "op": "SYNC", - "range": [0, 99], - "room_ids": [space_room_id, room_id], - } - ], - ) - self.assertListEqual( - list(channel.json_body["lists"]["foo-list"]["ops"]), - [ - { - "op": "SYNC", - "range": [0, 99], - "room_ids": [space_room_id], - } - ], - ) - - def test_sort_list(self) -> None: - """ - Test that the `lists` are sorted by `stream_ordering` - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True) - room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True) - room_id3 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True) - - # Activity that will order the rooms - self.helper.send(room_id3, "activity in room3", tok=user1_tok) - self.helper.send(room_id1, "activity in room1", tok=user1_tok) - self.helper.send(room_id2, "activity in room2", tok=user1_tok) - - # Make the Sliding Sync request - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 99]], - "required_state": [], - "timeline_limit": 1, - } - } - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # Make sure it has the foo-list we requested - self.assertListEqual( - 
list(response_body["lists"].keys()), - ["foo-list"], - response_body["lists"].keys(), - ) - - # Make sure the list is sorted in the way we expect - self.assertListEqual( - list(response_body["lists"]["foo-list"]["ops"]), - [ - { - "op": "SYNC", - "range": [0, 99], - "room_ids": [room_id2, room_id1, room_id3], - } - ], - response_body["lists"]["foo-list"], - ) - - def test_sliced_windows(self) -> None: - """ - Test that the `lists` `ranges` are sliced correctly. Both sides of each range - are inclusive. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - _room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True) - room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True) - room_id3 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True) - - # Make the Sliding Sync request for a single room - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 0]], - "required_state": [], - "timeline_limit": 1, - } - } - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # Make sure it has the foo-list we requested - self.assertListEqual( - list(response_body["lists"].keys()), - ["foo-list"], - response_body["lists"].keys(), - ) - # Make sure the list is sorted in the way we expect - self.assertListEqual( - list(response_body["lists"]["foo-list"]["ops"]), - [ - { - "op": "SYNC", - "range": [0, 0], - "room_ids": [room_id3], - } - ], - response_body["lists"]["foo-list"], - ) - - # Make the Sliding Sync request for the first two rooms - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [], - "timeline_limit": 1, - } - } - } - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - - # Make sure it has the foo-list we requested - self.assertListEqual( - list(response_body["lists"].keys()), - ["foo-list"], - response_body["lists"].keys(), - ) - # Make sure the list is sorted in the way we expect - self.assertListEqual( - 
list(response_body["lists"]["foo-list"]["ops"]), - [ - { - "op": "SYNC", - "range": [0, 1], - "room_ids": [room_id3, room_id2], - } - ], - response_body["lists"]["foo-list"], - ) - - def test_rooms_with_no_updates_do_not_come_down_incremental_sync(self) -> None: - """ - Test that rooms with no updates are returned in subsequent incremental - syncs. - """ - - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok) - - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [], - "timeline_limit": 0, - } - } - } - - _, from_token = self.do_sync(sync_body, tok=user1_tok) - - # Make the incremental Sliding Sync request - response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok) - - # Nothing has happened in the room, so the room should not come down - # /sync. - self.assertIsNone(response_body["rooms"].get(room_id1)) - - def test_empty_initial_room_comes_down_sync(self) -> None: - """ - Test that rooms come down /sync even with empty required state and - timeline limit in initial sync. - """ - - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok) - - sync_body = { - "lists": { - "foo-list": { - "ranges": [[0, 1]], - "required_state": [], - "timeline_limit": 0, - } - } - } - - # Make the Sliding Sync request - response_body, _ = self.do_sync(sync_body, tok=user1_tok) - self.assertEqual(response_body["rooms"][room_id1]["initial"], True) diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index a85ea994de..f1e4bdea76 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -427,23 +426,13 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): text = None for part in mail.walk(): if part.get_content_type() == "text/plain": - text = part.get_payload(decode=True) - if text is not None: - # According to the logic table in `get_payload`, we know that - # the result of `get_payload` will be `bytes`, but mypy doesn't - # know this and complains. Thus, we assert the type. - assert isinstance(text, bytes) - text = text.decode("UTF-8") - + text = part.get_payload(decode=True).decode("UTF-8") break if not text: self.fail("Could not find text portion of email to parse") - # `text` must be a `str`, after being decoded and determined just above - # to not be `None` or an empty `str`. - assert isinstance(text, str) - + assert text is not None match = re.search(r"https://example.com\S+", text) assert match, "Could not find link in email" @@ -1219,23 +1208,13 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): text = None for part in mail.walk(): if part.get_content_type() == "text/plain": - text = part.get_payload(decode=True) - if text is not None: - # According to the logic table in `get_payload`, we know that - # the result of `get_payload` will be `bytes`, but mypy doesn't - # know this and complains. Thus, we assert the type. - assert isinstance(text, bytes) - text = text.decode("UTF-8") - + text = part.get_payload(decode=True).decode("UTF-8") break if not text: self.fail("Could not find text portion of email to parse") - # `text` must be a `str`, after being decoded and determined just above - # to not be `None` or an empty `str`. 
- assert isinstance(text, str) - + assert text is not None match = re.search(r"https://example.com\S+", text) assert match, "Could not find link in email" diff --git a/tests/rest/client/test_account_data.py b/tests/rest/client/test_account_data.py
index be6d7af2fc..ce505eef62 100644 --- a/tests/rest/client/test_account_data.py +++ b/tests/rest/client/test_account_data.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 0b5daf4bb4..8bc36209e5 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020-2021 The Matrix.org Foundation C.I.C # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py
index a3ed12a38f..f1cfe0df91 100644 --- a/tests/rest/client/test_devices.py +++ b/tests/rest/client/test_devices.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -24,8 +23,8 @@ from twisted.internet.defer import ensureDeferred from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import NotFoundError -from synapse.rest import admin, devices, sync -from synapse.rest.client import keys, login, register +from synapse.rest import admin, devices, room, sync +from synapse.rest.client import account, keys, login, register from synapse.server import HomeServer from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock @@ -33,6 +32,146 @@ from synapse.util import Clock from tests import unittest +class DeviceListsTestCase(unittest.HomeserverTestCase): + """Tests regarding device list changes.""" + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + account.register_servlets, + room.register_servlets, + sync.register_servlets, + devices.register_servlets, + ] + + def test_receiving_local_device_list_changes(self) -> None: + """Tests that a local users that share a room receive each other's device list + changes. 
+ """ + # Register two users + test_device_id = "TESTDEVICE" + alice_user_id = self.register_user("alice", "correcthorse") + alice_access_token = self.login( + alice_user_id, "correcthorse", device_id=test_device_id + ) + + bob_user_id = self.register_user("bob", "ponyponypony") + bob_access_token = self.login(bob_user_id, "ponyponypony") + + # Create a room for them to coexist peacefully in + new_room_id = self.helper.create_room_as( + alice_user_id, is_public=True, tok=alice_access_token + ) + self.assertIsNotNone(new_room_id) + + # Have Bob join the room + self.helper.invite( + new_room_id, alice_user_id, bob_user_id, tok=alice_access_token + ) + self.helper.join(new_room_id, bob_user_id, tok=bob_access_token) + + # Now have Bob initiate an initial sync (in order to get a since token) + channel = self.make_request( + "GET", + "/sync", + access_token=bob_access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + next_batch_token = channel.json_body["next_batch"] + + # ...and then an incremental sync. This should block until the sync stream is woken up, + # which we hope will happen as a result of Alice updating their device list. + bob_sync_channel = self.make_request( + "GET", + f"/sync?since={next_batch_token}&timeout=30000", + access_token=bob_access_token, + # Start the request, then continue on. + await_result=False, + ) + + # Have alice update their device list + channel = self.make_request( + "PUT", + f"/devices/{test_device_id}", + { + "display_name": "New Device Name", + }, + access_token=alice_access_token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + + # Check that bob's incremental sync contains the updated device list. + # If not, the client would only receive the device list update on the + # *next* sync. 
+ bob_sync_channel.await_result() + self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body) + + changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get( + "changed", [] + ) + self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body) + + def test_not_receiving_local_device_list_changes(self) -> None: + """Tests a local users DO NOT receive device updates from each other if they do not + share a room. + """ + # Register two users + test_device_id = "TESTDEVICE" + alice_user_id = self.register_user("alice", "correcthorse") + alice_access_token = self.login( + alice_user_id, "correcthorse", device_id=test_device_id + ) + + bob_user_id = self.register_user("bob", "ponyponypony") + bob_access_token = self.login(bob_user_id, "ponyponypony") + + # These users do not share a room. They are lonely. + + # Have Bob initiate an initial sync (in order to get a since token) + channel = self.make_request( + "GET", + "/sync", + access_token=bob_access_token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + next_batch_token = channel.json_body["next_batch"] + + # ...and then an incremental sync. This should block until the sync stream is woken up, + # which we hope will happen as a result of Alice updating their device list. + bob_sync_channel = self.make_request( + "GET", + f"/sync?since={next_batch_token}&timeout=1000", + access_token=bob_access_token, + # Start the request, then continue on. + await_result=False, + ) + + # Have alice update their device list + channel = self.make_request( + "PUT", + f"/devices/{test_device_id}", + { + "display_name": "New Device Name", + }, + access_token=alice_access_token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + + # Check that bob's incremental sync does not contain the updated device list. 
+ bob_sync_channel.await_result() + self.assertEqual( + bob_sync_channel.code, HTTPStatus.OK, bob_sync_channel.json_body + ) + + changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get( + "changed", [] + ) + self.assertNotIn( + alice_user_id, changed_device_lists, bob_sync_channel.json_body + ) + + class DevicesTestCase(unittest.HomeserverTestCase): servlets = [ admin.register_servlets, diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py
index 06f1c1b234..c456b24839 100644 --- a/tests/rest/client/test_events.py +++ b/tests/rest/client/test_events.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2014-2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index 9cfc6b224f..b3b108dd5e 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2015, 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -72,7 +71,7 @@ class FilterTestCase(unittest.HomeserverTestCase): def test_add_filter_non_local_user(self) -> None: _is_mine = self.hs.is_mine - self.hs.is_mine = lambda target_user: False # type: ignore[assignment] + self.hs.is_mine = lambda target_user: False # type: ignore[method-assign] channel = self.make_request( "POST", "/_matrix/client/r0/user/%s/filter" % (self.user_id), diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py
index 8bbd109092..e99160c5ac 100644 --- a/tests/rest/client/test_keys.py +++ b/tests/rest/client/test_keys.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -155,6 +154,71 @@ class KeyQueryTestCase(unittest.HomeserverTestCase): } def test_device_signing_with_uia(self) -> None: + """Device signing key upload requires UIA.""" + password = "wonderland" + device_id = "ABCDEFGHI" + alice_id = self.register_user("alice", password) + alice_token = self.login("alice", password, device_id=device_id) + + content = self.make_device_keys(alice_id, device_id) + + channel = self.make_request( + "POST", + "/_matrix/client/v3/keys/device_signing/upload", + content, + alice_token, + ) + + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # add UI auth + content["auth"] = { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": alice_id}, + "password": password, + "session": session, + } + + channel = self.make_request( + "POST", + "/_matrix/client/v3/keys/device_signing/upload", + content, + alice_token, + ) + + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + @override_config({"ui_auth": {"session_timeout": "15m"}}) + def test_device_signing_with_uia_session_timeout(self) -> None: + """Device signing key upload requires UIA buy passes with grace period.""" + password = "wonderland" + device_id = "ABCDEFGHI" + alice_id = self.register_user("alice", password) + alice_token = self.login("alice", password, device_id=device_id) + + content = self.make_device_keys(alice_id, device_id) + + channel = self.make_request( + "POST", + "/_matrix/client/v3/keys/device_signing/upload", + content, + alice_token, + ) + + 
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + @override_config( + { + "experimental_features": {"msc3967_enabled": True}, + "ui_auth": {"session_timeout": "15s"}, + } + ) + def test_device_signing_with_msc3967(self) -> None: + """Device signing key follows MSC3967 behaviour when enabled.""" password = "wonderland" device_id = "ABCDEFGHI" alice_id = self.register_user("alice", password) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index 2b1e44381b..42610308d5 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019-2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -20,17 +19,7 @@ # import time import urllib.parse -from typing import ( - Any, - BinaryIO, - Callable, - Collection, - Dict, - List, - Optional, - Tuple, - Union, -) +from typing import Any, Collection, Dict, List, Optional, Tuple, Union from unittest.mock import Mock from urllib.parse import urlencode @@ -44,9 +33,8 @@ import synapse.rest.admin from synapse.api.constants import ApprovalNoticeMedium, LoginType from synapse.api.errors import Codes from synapse.appservice import ApplicationService -from synapse.http.client import RawHeaders from synapse.module_api import ModuleApi -from synapse.rest.client import account, devices, login, logout, profile, register +from synapse.rest.client import devices, login, logout, register from synapse.rest.client.account import WhoamiRestServlet from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.server import HomeServer @@ -59,7 +47,6 @@ from tests.handlers.test_saml import has_saml2 from tests.rest.client.utils import TEST_OIDC_CONFIG from tests.server import FakeChannel from tests.test_utils.html_parsers import TestHtmlParser -from tests.test_utils.oidc import FakeOidcServer from tests.unittest import HomeserverTestCase, override_config, skip_unless try: @@ -189,6 +176,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # rc_login dict here, we need to set this manually as well "account": {"per_second": 10000, "burst_count": 10000}, }, + "experimental_features": {"msc4041_enabled": True}, } ) def test_POST_ratelimiting_per_address(self) -> None: @@ -240,6 +228,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # rc_login dict here, we need to set this manually as well "address": {"per_second": 
10000, "burst_count": 10000}, }, + "experimental_features": {"msc4041_enabled": True}, } ) def test_POST_ratelimiting_per_account(self) -> None: @@ -288,6 +277,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "address": {"per_second": 10000, "burst_count": 10000}, "failed_attempts": {"per_second": 0.17, "burst_count": 5}, }, + "experimental_features": {"msc4041_enabled": True}, } ) def test_POST_ratelimiting_per_account_failed_attempts(self) -> None: @@ -969,8 +959,9 @@ class CASTestCase(unittest.HomeserverTestCase): # Test that the response is HTML. self.assertEqual(channel.code, 200, channel.result) content_type_header_value = "" - for header in channel.headers.getRawHeaders("Content-Type", []): - content_type_header_value = header + for header in channel.result.get("headers", []): + if header[0] == b"Content-Type": + content_type_header_value = header[1].decode("utf8") self.assertTrue(content_type_header_value.startswith("text/html")) @@ -1432,19 +1423,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): class UsernamePickerTestCase(HomeserverTestCase): """Tests for the username picker flow of SSO login""" - servlets = [ - login.register_servlets, - profile.register_servlets, - account.register_servlets, - ] - - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.http_client = Mock(spec=["get_file"]) - self.http_client.get_file.side_effect = mock_get_file - hs = self.setup_test_homeserver( - proxied_blocklisted_http_client=self.http_client - ) - return hs + servlets = [login.register_servlets] def default_config(self) -> Dict[str, Any]: config = super().default_config() @@ -1453,11 +1432,7 @@ class UsernamePickerTestCase(HomeserverTestCase): config["oidc_config"] = {} config["oidc_config"].update(TEST_OIDC_CONFIG) config["oidc_config"]["user_mapping_provider"] = { - "config": { - "display_name_template": "{{ user.displayname }}", - "email_template": "{{ user.email }}", - 
"picture_template": "{{ user.picture }}", - } + "config": {"display_name_template": "{{ user.displayname }}"} } # whitelist this client URI so we redirect straight to it rather than @@ -1470,22 +1445,15 @@ class UsernamePickerTestCase(HomeserverTestCase): d.update(build_synapse_client_resource_tree(self.hs)) return d - def proceed_to_username_picker_page( - self, - fake_oidc_server: FakeOidcServer, - displayname: str, - email: str, - picture: str, - ) -> Tuple[str, str]: + def test_username_picker(self) -> None: + """Test the happy path of a username picker flow.""" + + fake_oidc_server = self.helper.fake_oidc_server() + # do the start of the login flow channel, _ = self.helper.auth_via_oidc( fake_oidc_server, - { - "sub": "tester", - "displayname": displayname, - "picture": picture, - "email": email, - }, + {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL, ) @@ -1512,132 +1480,16 @@ class UsernamePickerTestCase(HomeserverTestCase): ) session = username_mapping_sessions[session_id] self.assertEqual(session.remote_user_id, "tester") - self.assertEqual(session.display_name, displayname) - self.assertEqual(session.emails, [email]) - self.assertEqual(session.avatar_url, picture) + self.assertEqual(session.display_name, "Jonny") self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL) # the expiry time should be about 15 minutes away expected_expiry = self.clock.time_msec() + (15 * 60 * 1000) self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000) - return picker_url, session_id - - def test_username_picker_use_displayname_avatar_and_email(self) -> None: - """Test the happy path of a username picker flow with using displayname, avatar and email.""" - - fake_oidc_server = self.helper.fake_oidc_server() - - mxid = "@bobby:test" - displayname = "Jonny" - email = "bobby@test.com" - picture = "mxc://test/avatar_url" - - picker_url, session_id = self.proceed_to_username_picker_page( - fake_oidc_server, displayname, 
email, picture - ) - - # Now, submit a username to the username picker, which should serve a redirect - # to the completion page. - # Also specify that we should use the provided displayname, avatar and email. - content = urlencode( - { - b"username": b"bobby", - b"use_display_name": b"true", - b"use_avatar": b"true", - b"use_email": email, - } - ).encode("utf8") - chan = self.make_request( - "POST", - path=picker_url, - content=content, - content_is_form=True, - custom_headers=[ - ("Cookie", "username_mapping_session=" + session_id), - # old versions of twisted don't do form-parsing without a valid - # content-length header. - ("Content-Length", str(len(content))), - ], - ) - self.assertEqual(chan.code, 302, chan.result) - location_headers = chan.headers.getRawHeaders("Location") - assert location_headers - - # send a request to the completion page, which should 302 to the client redirectUrl - chan = self.make_request( - "GET", - path=location_headers[0], - custom_headers=[("Cookie", "username_mapping_session=" + session_id)], - ) - self.assertEqual(chan.code, 302, chan.result) - location_headers = chan.headers.getRawHeaders("Location") - assert location_headers - - # ensure that the returned location matches the requested redirect URL - path, query = location_headers[0].split("?", 1) - self.assertEqual(path, "https://x") - - # it will have url-encoded the params properly, so we'll have to parse them - params = urllib.parse.parse_qsl( - query, keep_blank_values=True, strict_parsing=True, errors="strict" - ) - self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) - self.assertEqual(params[2][0], "loginToken") - - # fish the login token out of the returned redirect uri - login_token = params[2][1] - - # finally, submit the matrix login token to the login API, which gives us our - # matrix access token, mxid, and device id. 
- chan = self.make_request( - "POST", - "/login", - content={"type": "m.login.token", "token": login_token}, - ) - self.assertEqual(chan.code, 200, chan.result) - self.assertEqual(chan.json_body["user_id"], mxid) - - # ensure the displayname and avatar from the OIDC response have been configured for the user. - channel = self.make_request( - "GET", "/profile/" + mxid, access_token=chan.json_body["access_token"] - ) - self.assertEqual(channel.code, 200, channel.result) - self.assertIn("mxc://test", channel.json_body["avatar_url"]) - self.assertEqual(displayname, channel.json_body["displayname"]) - - # ensure the email from the OIDC response has been configured for the user. - channel = self.make_request( - "GET", "/account/3pid", access_token=chan.json_body["access_token"] - ) - self.assertEqual(channel.code, 200, channel.result) - self.assertEqual(email, channel.json_body["threepids"][0]["address"]) - - def test_username_picker_dont_use_displayname_avatar_or_email(self) -> None: - """Test the happy path of a username picker flow without using displayname, avatar or email.""" - - fake_oidc_server = self.helper.fake_oidc_server() - - mxid = "@bobby:test" - displayname = "Jonny" - email = "bobby@test.com" - picture = "mxc://test/avatar_url" - username = "bobby" - - picker_url, session_id = self.proceed_to_username_picker_page( - fake_oidc_server, displayname, email, picture - ) - # Now, submit a username to the username picker, which should serve a redirect - # to the completion page. - # Also specify that we should not use the provided displayname, avatar or email. 
- content = urlencode( - { - b"username": username, - b"use_display_name": b"false", - b"use_avatar": b"false", - } - ).encode("utf8") + # to the completion page + content = urlencode({b"username": b"bobby"}).encode("utf8") chan = self.make_request( "POST", path=picker_url, @@ -1686,29 +1538,4 @@ class UsernamePickerTestCase(HomeserverTestCase): content={"type": "m.login.token", "token": login_token}, ) self.assertEqual(chan.code, 200, chan.result) - self.assertEqual(chan.json_body["user_id"], mxid) - - # ensure the displayname and avatar from the OIDC response have not been configured for the user. - channel = self.make_request( - "GET", "/profile/" + mxid, access_token=chan.json_body["access_token"] - ) - self.assertEqual(channel.code, 200, channel.result) - self.assertNotIn("avatar_url", channel.json_body) - self.assertEqual(username, channel.json_body["displayname"]) - - # ensure the email from the OIDC response has not been configured for the user. - channel = self.make_request( - "GET", "/account/3pid", access_token=chan.json_body["access_token"] - ) - self.assertEqual(channel.code, 200, channel.result) - self.assertListEqual([], channel.json_body["threepids"]) - - -async def mock_get_file( - url: str, - output_stream: BinaryIO, - max_size: Optional[int] = None, - headers: Optional[RawHeaders] = None, - is_allowed_content_type: Optional[Callable[[str], bool]] = None, -) -> Tuple[int, Dict[bytes, List[bytes]], str, int]: - return 0, {b"Content-Type": [b"image/png"]}, "", 200 + self.assertEqual(chan.json_body["user_id"], "@bobby:test") diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py
index fbacf9d869..2913679c8b 100644 --- a/tests/rest/client/test_login_token_request.py +++ b/tests/rest/client/test_login_token_request.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py deleted file mode 100644
index 30b6d31d0a..0000000000 --- a/tests/rest/client/test_media.py +++ /dev/null
@@ -1,2677 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright 2022 The Matrix.org Foundation C.I.C. -# Copyright (C) 2024 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -# Originally licensed under the Apache License, Version 2.0: -# <http://www.apache.org/licenses/LICENSE-2.0>. -# -# [This file includes modifications made by New Vector Limited] -# -# -import base64 -import io -import json -import os -import re -import shutil -from typing import Any, BinaryIO, Dict, List, Optional, Sequence, Tuple, Type -from unittest.mock import MagicMock, Mock, patch -from urllib import parse -from urllib.parse import quote, urlencode - -from parameterized import parameterized, parameterized_class -from PIL import Image as Image -from typing_extensions import ClassVar - -from twisted.internet import defer -from twisted.internet._resolver import HostResolution -from twisted.internet.address import IPv4Address, IPv6Address -from twisted.internet.defer import Deferred -from twisted.internet.error import DNSLookupError -from twisted.internet.interfaces import IAddress, IResolutionReceiver -from twisted.python.failure import Failure -from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor -from twisted.web.http_headers import Headers -from twisted.web.iweb import UNKNOWN_LENGTH, IResponse -from twisted.web.resource import Resource - -from synapse.api.errors import HttpResponseException -from synapse.api.ratelimiting import Ratelimiter -from synapse.config.oembed import OEmbedEndpointConfig -from synapse.http.client import MultipartResponse -from synapse.http.types 
import QueryParams -from synapse.logging.context import make_deferred_yieldable -from synapse.media._base import FileInfo, ThumbnailInfo -from synapse.media.thumbnailer import ThumbnailProvider -from synapse.media.url_previewer import IMAGE_CACHE_EXPIRY_MS -from synapse.rest import admin -from synapse.rest.client import login, media -from synapse.server import HomeServer -from synapse.types import JsonDict, UserID -from synapse.util import Clock -from synapse.util.stringutils import parse_and_validate_mxc_uri - -from tests import unittest -from tests.media.test_media_storage import ( - SVG, - TestImage, - empty_file, - small_lossless_webp, - small_png, - small_png_with_transparency, -) -from tests.server import FakeChannel, FakeTransport, ThreadedMemoryReactorClock -from tests.test_utils import SMALL_PNG -from tests.unittest import override_config - -try: - import lxml -except ImportError: - lxml = None # type: ignore[assignment] - - -class MediaDomainBlockingTests(unittest.HomeserverTestCase): - remote_media_id = "doesnotmatter" - remote_server_name = "evil.com" - servlets = [ - media.register_servlets, - admin.register_servlets, - login.register_servlets, - ] - - def make_homeserver( - self, reactor: ThreadedMemoryReactorClock, clock: Clock - ) -> HomeServer: - config = self.default_config() - - self.storage_path = self.mktemp() - self.media_store_path = self.mktemp() - os.mkdir(self.storage_path) - os.mkdir(self.media_store_path) - config["media_store_path"] = self.media_store_path - - provider_config = { - "module": "synapse.media.storage_provider.FileStorageProviderBackend", - "store_local": True, - "store_synchronous": False, - "store_remote": True, - "config": {"directory": self.storage_path}, - } - - config["media_storage_providers"] = [provider_config] - - return self.setup_test_homeserver(config=config) - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastores().main - - # Inject a piece of 
media. We'll use this to ensure we're returning a sane - # response when we're not supposed to block it, distinguishing a media block - # from a regular 404. - file_id = "abcdefg12345" - file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id) - - media_storage = hs.get_media_repository().media_storage - - ctx = media_storage.store_into_file(file_info) - (f, fname) = self.get_success(ctx.__aenter__()) - f.write(SMALL_PNG) - self.get_success(ctx.__aexit__(None, None, None)) - - self.get_success( - self.store.store_cached_remote_media( - origin=self.remote_server_name, - media_id=self.remote_media_id, - media_type="image/png", - media_length=1, - time_now_ms=clock.time_msec(), - upload_name="test.png", - filesystem_id=file_id, - ) - ) - self.register_user("user", "password") - self.tok = self.login("user", "password") - - @override_config( - { - # Disable downloads from the domain we'll be trying to download from. - # Should result in a 404. - "prevent_media_downloads_from": ["evil.com"], - "dynamic_thumbnails": True, - } - ) - def test_cannot_download_blocked_media_thumbnail(self) -> None: - """ - Same test as test_cannot_download_blocked_media but for thumbnails. - """ - response = self.make_request( - "GET", - f"/_matrix/client/v1/media/thumbnail/evil.com/{self.remote_media_id}?width=100&height=100", - shorthand=False, - content={"width": 100, "height": 100}, - access_token=self.tok, - ) - self.assertEqual(response.code, 404) - - @override_config( - { - # Disable downloads from a domain we won't be requesting downloads from. - # This proves we haven't broken anything. - "prevent_media_downloads_from": ["not-listed.com"], - "dynamic_thumbnails": True, - } - ) - def test_remote_media_thumbnail_normally_unblocked(self) -> None: - """ - Same test as test_remote_media_normally_unblocked but for thumbnails. 
- """ - response = self.make_request( - "GET", - f"/_matrix/client/v1/media/thumbnail/evil.com/{self.remote_media_id}?width=100&height=100", - shorthand=False, - access_token=self.tok, - ) - self.assertEqual(response.code, 200) - - -class URLPreviewTests(unittest.HomeserverTestCase): - if not lxml: - skip = "url preview feature requires lxml" - - servlets = [media.register_servlets] - hijack_auth = True - user_id = "@test:user" - end_content = ( - b"<html><head>" - b'<meta property="og:title" content="~matrix~" />' - b'<meta property="og:description" content="hi" />' - b"</head></html>" - ) - - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() - config["url_preview_enabled"] = True - config["max_spider_size"] = 9999999 - config["url_preview_ip_range_blacklist"] = ( - "192.168.1.1", - "1.0.0.0/8", - "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", - "2001:800::/21", - ) - config["url_preview_ip_range_whitelist"] = ("1.1.1.1",) - config["url_preview_accept_language"] = [ - "en-UK", - "en-US;q=0.9", - "fr;q=0.8", - "*;q=0.7", - ] - - self.storage_path = self.mktemp() - self.media_store_path = self.mktemp() - os.mkdir(self.storage_path) - os.mkdir(self.media_store_path) - config["media_store_path"] = self.media_store_path - - provider_config = { - "module": "synapse.media.storage_provider.FileStorageProviderBackend", - "store_local": True, - "store_synchronous": False, - "store_remote": True, - "config": {"directory": self.storage_path}, - } - - config["media_storage_providers"] = [provider_config] - - hs = self.setup_test_homeserver(config=config) - - # After the hs is created, modify the parsed oEmbed config (to avoid - # messing with files). - # - # Note that HTTP URLs are used to avoid having to deal with TLS in tests. 
- hs.config.oembed.oembed_patterns = [ - OEmbedEndpointConfig( - api_endpoint="http://publish.twitter.com/oembed", - url_patterns=[ - re.compile(r"http://twitter\.com/.+/status/.+"), - ], - formats=None, - ), - OEmbedEndpointConfig( - api_endpoint="http://www.hulu.com/api/oembed.{format}", - url_patterns=[ - re.compile(r"http://www\.hulu\.com/watch/.+"), - ], - formats=["json"], - ), - ] - - return hs - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.media_repo = hs.get_media_repository() - assert self.media_repo.url_previewer is not None - self.url_previewer = self.media_repo.url_previewer - - self.lookups: Dict[str, Any] = {} - - class Resolver: - def resolveHostName( - _self, - resolutionReceiver: IResolutionReceiver, - hostName: str, - portNumber: int = 0, - addressTypes: Optional[Sequence[Type[IAddress]]] = None, - transportSemantics: str = "TCP", - ) -> IResolutionReceiver: - resolution = HostResolution(hostName) - resolutionReceiver.resolutionBegan(resolution) - if hostName not in self.lookups: - raise DNSLookupError("OH NO") - - for i in self.lookups[hostName]: - resolutionReceiver.addressResolved(i[0]("TCP", i[1], portNumber)) - resolutionReceiver.resolutionComplete() - return resolutionReceiver - - self.reactor.nameResolver = Resolver() # type: ignore[assignment] - - def _assert_small_png(self, json_body: JsonDict) -> None: - """Assert properties from the SMALL_PNG test image.""" - self.assertTrue(json_body["og:image"].startswith("mxc://")) - self.assertEqual(json_body["og:image:height"], 1) - self.assertEqual(json_body["og:image:width"], 1) - self.assertEqual(json_body["og:image:type"], "image/png") - self.assertEqual(json_body["matrix:image:size"], 67) - - def test_cache_returns_correct_type(self) -> None: - self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://matrix.org", - shorthand=False, - await_result=False, - ) 
- self.pump() - - client = self.reactor.tcpClients[0][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n" - % (len(self.end_content),) - + self.end_content - ) - - self.pump() - self.assertEqual(channel.code, 200) - self.assertEqual( - channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} - ) - - # Check the cache returns the correct response - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://matrix.org", - shorthand=False, - ) - - # Check the cache response has the same content - self.assertEqual(channel.code, 200) - self.assertEqual( - channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} - ) - - # Clear the in-memory cache - self.assertIn("http://matrix.org", self.url_previewer._cache) - self.url_previewer._cache.pop("http://matrix.org") - self.assertNotIn("http://matrix.org", self.url_previewer._cache) - - # Check the database cache returns the correct response - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://matrix.org", - shorthand=False, - ) - - # Check the cache response has the same content - self.assertEqual(channel.code, 200) - self.assertEqual( - channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} - ) - - def test_non_ascii_preview_httpequiv(self) -> None: - self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] - - end_content = ( - b"<html><head>" - b'<meta http-equiv="Content-Type" content="text/html; charset=windows-1251"/>' - b'<meta property="og:title" content="\xe4\xea\xe0" />' - b'<meta property="og:description" content="hi" />' - b"</head></html>" - ) - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://matrix.org", - shorthand=False, - await_result=False, 
- ) - self.pump() - - client = self.reactor.tcpClients[0][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - ( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b'Content-Type: text/html; charset="utf8"\r\n\r\n' - ) - % (len(end_content),) - + end_content - ) - - self.pump() - self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430") - - def test_video_rejected(self) -> None: - self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] - - end_content = b"anything" - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://matrix.org", - shorthand=False, - await_result=False, - ) - self.pump() - - client = self.reactor.tcpClients[0][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - ( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b"Content-Type: video/mp4\r\n\r\n" - ) - % (len(end_content)) - + end_content - ) - - self.pump() - self.assertEqual(channel.code, 502) - self.assertEqual( - channel.json_body, - { - "errcode": "M_UNKNOWN", - "error": "Requested file's content type not allowed for this operation: video/mp4", - }, - ) - - def test_audio_rejected(self) -> None: - self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] - - end_content = b"anything" - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://matrix.org", - shorthand=False, - await_result=False, - ) - self.pump() - - client = self.reactor.tcpClients[0][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - ( - b"HTTP/1.0 200 
OK\r\nContent-Length: %d\r\n" - b"Content-Type: audio/aac\r\n\r\n" - ) - % (len(end_content)) - + end_content - ) - - self.pump() - self.assertEqual(channel.code, 502) - self.assertEqual( - channel.json_body, - { - "errcode": "M_UNKNOWN", - "error": "Requested file's content type not allowed for this operation: audio/aac", - }, - ) - - def test_non_ascii_preview_content_type(self) -> None: - self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] - - end_content = ( - b"<html><head>" - b'<meta property="og:title" content="\xe4\xea\xe0" />' - b'<meta property="og:description" content="hi" />' - b"</head></html>" - ) - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://matrix.org", - shorthand=False, - await_result=False, - ) - self.pump() - - client = self.reactor.tcpClients[0][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - ( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b'Content-Type: text/html; charset="windows-1251"\r\n\r\n' - ) - % (len(end_content),) - + end_content - ) - - self.pump() - self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430") - - def test_overlong_title(self) -> None: - self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] - - end_content = ( - b"<html><head>" - b"<title>" + b"x" * 2000 + b"</title>" - b'<meta property="og:description" content="hi" />' - b"</head></html>" - ) - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://matrix.org", - shorthand=False, - await_result=False, - ) - self.pump() - - client = self.reactor.tcpClients[0][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - ( - b"HTTP/1.0 
200 OK\r\nContent-Length: %d\r\n" - b'Content-Type: text/html; charset="windows-1251"\r\n\r\n' - ) - % (len(end_content),) - + end_content - ) - - self.pump() - self.assertEqual(channel.code, 200) - res = channel.json_body - # We should only see the `og:description` field, as `title` is too long and should be stripped out - self.assertCountEqual(["og:description"], res.keys()) - - def test_ipaddr(self) -> None: - """ - IP addresses can be previewed directly. - """ - self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")] - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://example.com", - shorthand=False, - await_result=False, - ) - self.pump() - - client = self.reactor.tcpClients[0][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n" - % (len(self.end_content),) - + self.end_content - ) - - self.pump() - self.assertEqual(channel.code, 200) - self.assertEqual( - channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} - ) - - def test_blocked_ip_specific(self) -> None: - """ - Blocked IP addresses, found via DNS, are not spidered. - """ - self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")] - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://example.com", - shorthand=False, - ) - - # No requests made. - self.assertEqual(len(self.reactor.tcpClients), 0) - self.assertEqual(channel.code, 502) - self.assertEqual( - channel.json_body, - { - "errcode": "M_UNKNOWN", - "error": "DNS resolution failure during URL preview generation", - }, - ) - - def test_blocked_ip_range(self) -> None: - """ - Blocked IP ranges, IPs found over DNS, are not spidered. 
- """ - self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")] - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://example.com", - shorthand=False, - ) - - self.assertEqual(channel.code, 502) - self.assertEqual( - channel.json_body, - { - "errcode": "M_UNKNOWN", - "error": "DNS resolution failure during URL preview generation", - }, - ) - - def test_blocked_ip_specific_direct(self) -> None: - """ - Blocked IP addresses, accessed directly, are not spidered. - """ - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://192.168.1.1", - shorthand=False, - ) - - # No requests made. - self.assertEqual(len(self.reactor.tcpClients), 0) - self.assertEqual( - channel.json_body, - {"errcode": "M_UNKNOWN", "error": "IP address blocked"}, - ) - self.assertEqual(channel.code, 403) - - def test_blocked_ip_range_direct(self) -> None: - """ - Blocked IP ranges, accessed directly, are not spidered. - """ - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://1.1.1.2", - shorthand=False, - ) - - self.assertEqual(channel.code, 403) - self.assertEqual( - channel.json_body, - {"errcode": "M_UNKNOWN", "error": "IP address blocked"}, - ) - - def test_blocked_ip_range_whitelisted_ip(self) -> None: - """ - Blocked but then subsequently whitelisted IP addresses can be - spidered. 
- """ - self.lookups["example.com"] = [(IPv4Address, "1.1.1.1")] - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://example.com", - shorthand=False, - await_result=False, - ) - self.pump() - - client = self.reactor.tcpClients[0][2].buildProtocol(None) - - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - - client.dataReceived( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n" - % (len(self.end_content),) - + self.end_content - ) - - self.pump() - self.assertEqual(channel.code, 200) - self.assertEqual( - channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} - ) - - def test_blocked_ip_with_external_ip(self) -> None: - """ - If a hostname resolves a blocked IP, even if there's a non-blocked one, - it will be rejected. - """ - # Hardcode the URL resolving to the IP we want. - self.lookups["example.com"] = [ - (IPv4Address, "1.1.1.2"), - (IPv4Address, "10.1.2.3"), - ] - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://example.com", - shorthand=False, - ) - self.assertEqual(channel.code, 502) - self.assertEqual( - channel.json_body, - { - "errcode": "M_UNKNOWN", - "error": "DNS resolution failure during URL preview generation", - }, - ) - - def test_blocked_ipv6_specific(self) -> None: - """ - Blocked IP addresses, found via DNS, are not spidered. - """ - self.lookups["example.com"] = [ - (IPv6Address, "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff") - ] - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://example.com", - shorthand=False, - ) - - # No requests made. 
- self.assertEqual(len(self.reactor.tcpClients), 0) - self.assertEqual(channel.code, 502) - self.assertEqual( - channel.json_body, - { - "errcode": "M_UNKNOWN", - "error": "DNS resolution failure during URL preview generation", - }, - ) - - def test_blocked_ipv6_range(self) -> None: - """ - Blocked IP ranges, IPs found over DNS, are not spidered. - """ - self.lookups["example.com"] = [(IPv6Address, "2001:800::1")] - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://example.com", - shorthand=False, - ) - - self.assertEqual(channel.code, 502) - self.assertEqual( - channel.json_body, - { - "errcode": "M_UNKNOWN", - "error": "DNS resolution failure during URL preview generation", - }, - ) - - def test_OPTIONS(self) -> None: - """ - OPTIONS returns the OPTIONS. - """ - channel = self.make_request( - "OPTIONS", - "/_matrix/client/v1/media/preview_url?url=http://example.com", - shorthand=False, - ) - self.assertEqual(channel.code, 204) - - def test_accept_language_config_option(self) -> None: - """ - Accept-Language header is sent to the remote server - """ - self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")] - - # Build and make a request to the server - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://example.com", - shorthand=False, - await_result=False, - ) - self.pump() - - # Extract Synapse's tcp client - client = self.reactor.tcpClients[0][2].buildProtocol(None) - - # Build a fake remote server to reply with - server = AccumulatingProtocol() - - # Connect the two together - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - - # Tell Synapse that it has received some data from the remote server - client.dataReceived( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n" - % (len(self.end_content),) - + self.end_content - ) - - # Move the reactor along until we get a response on our 
original channel - self.pump() - self.assertEqual(channel.code, 200) - self.assertEqual( - channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} - ) - - # Check that the server received the Accept-Language header as part - # of the request from Synapse - self.assertIn( - ( - b"Accept-Language: en-UK\r\n" - b"Accept-Language: en-US;q=0.9\r\n" - b"Accept-Language: fr;q=0.8\r\n" - b"Accept-Language: *;q=0.7" - ), - server.data, - ) - - def test_image(self) -> None: - """An image should be precached if mentioned in the HTML.""" - self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] - self.lookups["cdn.matrix.org"] = [(IPv4Address, "10.1.2.4")] - - result = ( - b"""<html><body><img src="http://cdn.matrix.org/foo.png"></body></html>""" - ) - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://matrix.org", - shorthand=False, - await_result=False, - ) - self.pump() - - # Respond with the HTML. - client = self.reactor.tcpClients[0][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - ( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b'Content-Type: text/html; charset="utf8"\r\n\r\n' - ) - % (len(result),) - + result - ) - self.pump() - - # Respond with the photo. - client = self.reactor.tcpClients[1][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - ( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b"Content-Type: image/png\r\n\r\n" - ) - % (len(SMALL_PNG),) - + SMALL_PNG - ) - self.pump() - - # The image should be in the result. 
- self.assertEqual(channel.code, 200) - self._assert_small_png(channel.json_body) - - def test_nonexistent_image(self) -> None: - """If the preview image doesn't exist, ensure some data is returned.""" - self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] - - result = ( - b"""<html><body><img src="http://cdn.matrix.org/foo.jpg"></body></html>""" - ) - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://matrix.org", - shorthand=False, - await_result=False, - ) - self.pump() - - client = self.reactor.tcpClients[0][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - ( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b'Content-Type: text/html; charset="utf8"\r\n\r\n' - ) - % (len(result),) - + result - ) - - self.pump() - - # There should not be a second connection. - self.assertEqual(len(self.reactor.tcpClients), 1) - - # The image should not be in the result. 
- self.assertEqual(channel.code, 200) - self.assertNotIn("og:image", channel.json_body) - - @unittest.override_config( - {"url_preview_url_blacklist": [{"netloc": "cdn.matrix.org"}]} - ) - def test_image_blocked(self) -> None: - """If the preview image doesn't exist, ensure some data is returned.""" - self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] - self.lookups["cdn.matrix.org"] = [(IPv4Address, "10.1.2.4")] - - result = ( - b"""<html><body><img src="http://cdn.matrix.org/foo.jpg"></body></html>""" - ) - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://matrix.org", - shorthand=False, - await_result=False, - ) - self.pump() - - client = self.reactor.tcpClients[0][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - ( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b'Content-Type: text/html; charset="utf8"\r\n\r\n' - ) - % (len(result),) - + result - ) - self.pump() - - # There should not be a second connection. - self.assertEqual(len(self.reactor.tcpClients), 1) - - # The image should not be in the result. 
- self.assertEqual(channel.code, 200) - self.assertNotIn("og:image", channel.json_body) - - def test_oembed_failure(self) -> None: - """If the autodiscovered oEmbed URL fails, ensure some data is returned.""" - self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] - - result = b""" - <title>oEmbed Autodiscovery Fail</title> - <link rel="alternate" type="application/json+oembed" - href="http://example.com/oembed?url=http%3A%2F%2Fmatrix.org&format=json" - title="matrixdotorg" /> - """ - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://matrix.org", - shorthand=False, - await_result=False, - ) - self.pump() - - client = self.reactor.tcpClients[0][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - ( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b'Content-Type: text/html; charset="utf8"\r\n\r\n' - ) - % (len(result),) - + result - ) - - self.pump() - self.assertEqual(channel.code, 200) - - # The image should not be in the result. - self.assertEqual(channel.json_body["og:title"], "oEmbed Autodiscovery Fail") - - def test_data_url(self) -> None: - """ - Requesting to preview a data URL is not supported. - """ - self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] - - data = base64.b64encode(SMALL_PNG).decode() - - query_params = urlencode( - { - "url": f'<html><head><img src="data:image/png;base64,{data}" /></head></html>' - } - ) - - channel = self.make_request( - "GET", - f"/_matrix/client/v1/media/preview_url?{query_params}", - shorthand=False, - ) - self.pump() - - self.assertEqual(channel.code, 500) - - def test_inline_data_url(self) -> None: - """ - An inline image (as a data URL) should be parsed properly. 
- """ - self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] - - data = base64.b64encode(SMALL_PNG) - - end_content = ( - b"<html><head>" b'<img src="data:image/png;base64,%s" />' b"</head></html>" - ) % (data,) - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://matrix.org", - shorthand=False, - await_result=False, - ) - self.pump() - - client = self.reactor.tcpClients[0][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - ( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b'Content-Type: text/html; charset="utf8"\r\n\r\n' - ) - % (len(end_content),) - + end_content - ) - - self.pump() - self.assertEqual(channel.code, 200) - self._assert_small_png(channel.json_body) - - def test_oembed_photo(self) -> None: - """Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL.""" - self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] - self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")] - - result = { - "version": "1.0", - "type": "photo", - "url": "http://cdn.twitter.com/matrixdotorg", - } - oembed_content = json.dumps(result).encode("utf-8") - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://twitter.com/matrixdotorg/status/12345", - shorthand=False, - await_result=False, - ) - self.pump() - - client = self.reactor.tcpClients[0][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - ( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b'Content-Type: application/json; charset="utf8"\r\n\r\n' - ) - % (len(oembed_content),) - + oembed_content - ) - - self.pump() - - # Ensure a second request is made to the photo URL. 
- client = self.reactor.tcpClients[1][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - ( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b"Content-Type: image/png\r\n\r\n" - ) - % (len(SMALL_PNG),) - + SMALL_PNG - ) - - self.pump() - - # Ensure the URL is what was requested. - self.assertIn(b"/matrixdotorg", server.data) - - self.assertEqual(channel.code, 200) - body = channel.json_body - self.assertEqual(body["og:url"], "http://twitter.com/matrixdotorg/status/12345") - self._assert_small_png(body) - - def test_oembed_rich(self) -> None: - """Test an oEmbed endpoint which returns HTML content via the 'rich' type.""" - self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] - - result = { - "version": "1.0", - "type": "rich", - # Note that this provides the author, not the title. - "author_name": "Alice", - "html": "<div>Content Preview</div>", - } - end_content = json.dumps(result).encode("utf-8") - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://twitter.com/matrixdotorg/status/12345", - shorthand=False, - await_result=False, - ) - self.pump() - - client = self.reactor.tcpClients[0][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - ( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b'Content-Type: application/json; charset="utf8"\r\n\r\n' - ) - % (len(end_content),) - + end_content - ) - - self.pump() - - # Double check that the proper host is being connected to. (Note that - # twitter.com can't be resolved so this is already implicitly checked.) 
- self.assertIn(b"\r\nHost: publish.twitter.com\r\n", server.data) - - self.assertEqual(channel.code, 200) - body = channel.json_body - self.assertEqual( - body, - { - "og:url": "http://twitter.com/matrixdotorg/status/12345", - "og:title": "Alice", - "og:description": "Content Preview", - }, - ) - - def test_oembed_format(self) -> None: - """Test an oEmbed endpoint which requires the format in the URL.""" - self.lookups["www.hulu.com"] = [(IPv4Address, "10.1.2.3")] - - result = { - "version": "1.0", - "type": "rich", - "html": "<div>Content Preview</div>", - } - end_content = json.dumps(result).encode("utf-8") - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://www.hulu.com/watch/12345", - shorthand=False, - await_result=False, - ) - self.pump() - - client = self.reactor.tcpClients[0][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - ( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b'Content-Type: application/json; charset="utf8"\r\n\r\n' - ) - % (len(end_content),) - + end_content - ) - - self.pump() - - # The {format} should have been turned into json. - self.assertIn(b"/api/oembed.json", server.data) - # A URL parameter of format=json should be provided. 
- self.assertIn(b"format=json", server.data) - - self.assertEqual(channel.code, 200) - body = channel.json_body - self.assertEqual( - body, - { - "og:url": "http://www.hulu.com/watch/12345", - "og:description": "Content Preview", - }, - ) - - @unittest.override_config( - {"url_preview_url_blacklist": [{"netloc": "publish.twitter.com"}]} - ) - def test_oembed_blocked(self) -> None: - """The oEmbed URL should not be downloaded if the oEmbed URL is blocked.""" - self.lookups["twitter.com"] = [(IPv4Address, "10.1.2.3")] - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://twitter.com/matrixdotorg/status/12345", - shorthand=False, - await_result=False, - ) - self.pump() - self.assertEqual(channel.code, 403, channel.result) - - def test_oembed_autodiscovery(self) -> None: - """ - Autodiscovery works by finding the link in the HTML response and then requesting an oEmbed URL. - 1. Request a preview of a URL which is not known to the oEmbed code. - 2. It returns HTML including a link to an oEmbed preview. - 3. The oEmbed preview is requested and returns a URL for an image. - 4. The image is requested for thumbnailing. - """ - # This is a little cheesy in that we use the www subdomain (which isn't the - # list of oEmbed patterns) to get "raw" HTML response. 
- self.lookups["www.twitter.com"] = [(IPv4Address, "10.1.2.3")] - self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] - self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")] - - result = b""" - <link rel="alternate" type="application/json+oembed" - href="http://publish.twitter.com/oembed?url=http%3A%2F%2Fcdn.twitter.com%2Fmatrixdotorg%2Fstatus%2F12345&format=json" - title="matrixdotorg" /> - """ - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://www.twitter.com/matrixdotorg/status/12345", - shorthand=False, - await_result=False, - ) - self.pump() - - client = self.reactor.tcpClients[0][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - ( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b'Content-Type: text/html; charset="utf8"\r\n\r\n' - ) - % (len(result),) - + result - ) - self.pump() - - # The oEmbed response. - result2 = { - "version": "1.0", - "type": "photo", - "url": "http://cdn.twitter.com/matrixdotorg", - } - oembed_content = json.dumps(result2).encode("utf-8") - - # Ensure a second request is made to the oEmbed URL. - client = self.reactor.tcpClients[1][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - ( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b'Content-Type: application/json; charset="utf8"\r\n\r\n' - ) - % (len(oembed_content),) - + oembed_content - ) - self.pump() - - # Ensure the URL is what was requested. - self.assertIn(b"/oembed?", server.data) - - # Ensure a third request is made to the photo URL. 
- client = self.reactor.tcpClients[2][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - ( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b"Content-Type: image/png\r\n\r\n" - ) - % (len(SMALL_PNG),) - + SMALL_PNG - ) - self.pump() - - # Ensure the URL is what was requested. - self.assertIn(b"/matrixdotorg", server.data) - - self.assertEqual(channel.code, 200) - body = channel.json_body - self.assertEqual( - body["og:url"], "http://www.twitter.com/matrixdotorg/status/12345" - ) - self._assert_small_png(body) - - @unittest.override_config( - {"url_preview_url_blacklist": [{"netloc": "publish.twitter.com"}]} - ) - def test_oembed_autodiscovery_blocked(self) -> None: - """ - If the discovered oEmbed URL is blocked, it should be discarded. - """ - # This is a little cheesy in that we use the www subdomain (which isn't the - # list of oEmbed patterns) to get "raw" HTML response. 
- self.lookups["www.twitter.com"] = [(IPv4Address, "10.1.2.3")] - self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.4")] - - result = b""" - <title>Test</title> - <link rel="alternate" type="application/json+oembed" - href="http://publish.twitter.com/oembed?url=http%3A%2F%2Fcdn.twitter.com%2Fmatrixdotorg%2Fstatus%2F12345&format=json" - title="matrixdotorg" /> - """ - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://www.twitter.com/matrixdotorg/status/12345", - shorthand=False, - await_result=False, - ) - self.pump() - - client = self.reactor.tcpClients[0][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - ( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b'Content-Type: text/html; charset="utf8"\r\n\r\n' - ) - % (len(result),) - + result - ) - - self.pump() - - # Ensure there's no additional connections. - self.assertEqual(len(self.reactor.tcpClients), 1) - - # Ensure the URL is what was requested. - self.assertIn(b"\r\nHost: www.twitter.com\r\n", server.data) - - self.assertEqual(channel.code, 200) - body = channel.json_body - self.assertEqual(body["og:title"], "Test") - self.assertNotIn("og:image", body) - - def _download_image(self) -> Tuple[str, str]: - """Downloads an image into the URL cache. - Returns: - A (host, media_id) tuple representing the MXC URI of the image. 
- """ - self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")] - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=http://cdn.twitter.com/matrixdotorg", - shorthand=False, - await_result=False, - ) - self.pump() - - client = self.reactor.tcpClients[0][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: image/png\r\n\r\n" - % (len(SMALL_PNG),) - + SMALL_PNG - ) - - self.pump() - self.assertEqual(channel.code, 200) - body = channel.json_body - mxc_uri = body["og:image"] - host, _port, media_id = parse_and_validate_mxc_uri(mxc_uri) - self.assertIsNone(_port) - return host, media_id - - def test_storage_providers_exclude_files(self) -> None: - """Test that files are not stored in or fetched from storage providers.""" - host, media_id = self._download_image() - - rel_file_path = self.media_repo.filepaths.url_cache_filepath_rel(media_id) - media_store_path = os.path.join(self.media_store_path, rel_file_path) - storage_provider_path = os.path.join(self.storage_path, rel_file_path) - - # Check storage - self.assertTrue(os.path.isfile(media_store_path)) - self.assertFalse( - os.path.isfile(storage_provider_path), - "URL cache file was unexpectedly stored in a storage provider", - ) - - # Check fetching - channel = self.make_request( - "GET", - f"/_matrix/client/v1/media/download/{host}/{media_id}", - shorthand=False, - await_result=False, - ) - self.pump() - self.assertEqual(channel.code, 200) - - # Move cached file into the storage provider - os.makedirs(os.path.dirname(storage_provider_path), exist_ok=True) - os.rename(media_store_path, storage_provider_path) - - channel = self.make_request( - "GET", - f"/_matrix/client/v1/download/{host}/{media_id}", - shorthand=False, - await_result=False, - ) - self.pump() - 
self.assertEqual( - channel.code, - 404, - "URL cache file was unexpectedly retrieved from a storage provider", - ) - - def test_storage_providers_exclude_thumbnails(self) -> None: - """Test that thumbnails are not stored in or fetched from storage providers.""" - host, media_id = self._download_image() - - rel_thumbnail_path = ( - self.media_repo.filepaths.url_cache_thumbnail_directory_rel(media_id) - ) - media_store_thumbnail_path = os.path.join( - self.media_store_path, rel_thumbnail_path - ) - storage_provider_thumbnail_path = os.path.join( - self.storage_path, rel_thumbnail_path - ) - - # Check storage - self.assertTrue(os.path.isdir(media_store_thumbnail_path)) - self.assertFalse( - os.path.isdir(storage_provider_thumbnail_path), - "URL cache thumbnails were unexpectedly stored in a storage provider", - ) - - # Check fetching - channel = self.make_request( - "GET", - f"/_matrix/client/v1/media/thumbnail/{host}/{media_id}?width=32&height=32&method=scale", - shorthand=False, - await_result=False, - ) - self.pump() - self.assertEqual(channel.code, 200) - - # Remove the original, otherwise thumbnails will regenerate - rel_file_path = self.media_repo.filepaths.url_cache_filepath_rel(media_id) - media_store_path = os.path.join(self.media_store_path, rel_file_path) - os.remove(media_store_path) - - # Move cached thumbnails into the storage provider - os.makedirs(os.path.dirname(storage_provider_thumbnail_path), exist_ok=True) - os.rename(media_store_thumbnail_path, storage_provider_thumbnail_path) - - channel = self.make_request( - "GET", - f"/_matrix/client/v1/media/thumbnail/{host}/{media_id}?width=32&height=32&method=scale", - shorthand=False, - await_result=False, - ) - self.pump() - self.assertEqual( - channel.code, - 404, - "URL cache thumbnail was unexpectedly retrieved from a storage provider", - ) - - def test_cache_expiry(self) -> None: - """Test that URL cache files and thumbnails are cleaned up properly on expiry.""" - _host, media_id = 
self._download_image() - - file_path = self.media_repo.filepaths.url_cache_filepath(media_id) - file_dirs = self.media_repo.filepaths.url_cache_filepath_dirs_to_delete( - media_id - ) - thumbnail_dir = self.media_repo.filepaths.url_cache_thumbnail_directory( - media_id - ) - thumbnail_dirs = self.media_repo.filepaths.url_cache_thumbnail_dirs_to_delete( - media_id - ) - - self.assertTrue(os.path.isfile(file_path)) - self.assertTrue(os.path.isdir(thumbnail_dir)) - - self.reactor.advance(IMAGE_CACHE_EXPIRY_MS * 1000 + 1) - self.get_success(self.url_previewer._expire_url_cache_data()) - - for path in [file_path] + file_dirs + [thumbnail_dir] + thumbnail_dirs: - self.assertFalse( - os.path.exists(path), - f"{os.path.relpath(path, self.media_store_path)} was not deleted", - ) - - @unittest.override_config({"url_preview_url_blacklist": [{"port": "*"}]}) - def test_blocked_port(self) -> None: - """Tests that blocking URLs with a port makes previewing such URLs - fail with a 403 error and doesn't impact other previews. 
- """ - self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] - - bad_url = quote("http://matrix.org:8888/foo") - good_url = quote("http://matrix.org/foo") - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=" + bad_url, - shorthand=False, - await_result=False, - ) - self.pump() - self.assertEqual(channel.code, 403, channel.result) - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=" + good_url, - shorthand=False, - await_result=False, - ) - self.pump() - - client = self.reactor.tcpClients[0][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n" - % (len(self.end_content),) - + self.end_content - ) - - self.pump() - self.assertEqual(channel.code, 200) - - @unittest.override_config( - {"url_preview_url_blacklist": [{"netloc": "example.com"}]} - ) - def test_blocked_url(self) -> None: - """Tests that blocking URLs with a host makes previewing such URLs - fail with a 403 error. 
- """ - self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")] - - bad_url = quote("http://example.com/foo") - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/preview_url?url=" + bad_url, - shorthand=False, - await_result=False, - ) - self.pump() - self.assertEqual(channel.code, 403, channel.result) - - -class MediaConfigTest(unittest.HomeserverTestCase): - servlets = [ - media.register_servlets, - admin.register_servlets, - login.register_servlets, - ] - - def make_homeserver( - self, reactor: ThreadedMemoryReactorClock, clock: Clock - ) -> HomeServer: - config = self.default_config() - - self.storage_path = self.mktemp() - self.media_store_path = self.mktemp() - os.mkdir(self.storage_path) - os.mkdir(self.media_store_path) - config["media_store_path"] = self.media_store_path - - provider_config = { - "module": "synapse.media.storage_provider.FileStorageProviderBackend", - "store_local": True, - "store_synchronous": False, - "store_remote": True, - "config": {"directory": self.storage_path}, - } - - config["media_storage_providers"] = [provider_config] - - return self.setup_test_homeserver(config=config) - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.register_user("user", "password") - self.tok = self.login("user", "password") - - def test_media_config(self) -> None: - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/config", - shorthand=False, - access_token=self.tok, - ) - self.assertEqual(channel.code, 200) - self.assertEqual( - channel.json_body["m.upload.size"], self.hs.config.media.max_upload_size - ) - - -class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase): - servlets = [ - media.register_servlets, - login.register_servlets, - admin.register_servlets, - ] - - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() - - self.storage_path = self.mktemp() - self.media_store_path = self.mktemp() - 
os.mkdir(self.storage_path) - os.mkdir(self.media_store_path) - config["media_store_path"] = self.media_store_path - - provider_config = { - "module": "synapse.media.storage_provider.FileStorageProviderBackend", - "store_local": True, - "store_synchronous": False, - "store_remote": True, - "config": {"directory": self.storage_path}, - } - - config["media_storage_providers"] = [provider_config] - - return self.setup_test_homeserver(config=config) - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.repo = hs.get_media_repository() - self.client = hs.get_federation_http_client() - self.store = hs.get_datastores().main - self.user = self.register_user("user", "pass") - self.tok = self.login("user", "pass") - - # mock actually reading file body - def read_multipart_response_30MiB(*args: Any, **kwargs: Any) -> Deferred: - d: Deferred = defer.Deferred() - d.callback(MultipartResponse(b"{}", 31457280, b"img/png", None)) - return d - - def read_multipart_response_50MiB(*args: Any, **kwargs: Any) -> Deferred: - d: Deferred = defer.Deferred() - d.callback(MultipartResponse(b"{}", 31457280, b"img/png", None)) - return d - - @patch( - "synapse.http.matrixfederationclient.read_multipart_response", - read_multipart_response_30MiB, - ) - def test_download_ratelimit_default(self) -> None: - """ - Test remote media download ratelimiting against default configuration - 500MB bucket - and 87kb/second drain rate - """ - - # mock out actually sending the request, returns a 30MiB response - async def _send_request(*args: Any, **kwargs: Any) -> IResponse: - resp = MagicMock(spec=IResponse) - resp.code = 200 - resp.length = 31457280 - resp.headers = Headers( - {"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]} - ) - resp.phrase = b"OK" - return resp - - self.client._send_request = _send_request # type: ignore - - # first request should go through - channel = self.make_request( - "GET", - 
"/_matrix/client/v1/media/download/remote.org/abc", - shorthand=False, - access_token=self.tok, - ) - assert channel.code == 200 - - # next 15 should go through - for i in range(15): - channel2 = self.make_request( - "GET", - f"/_matrix/client/v1/media/download/remote.org/abc{i}", - shorthand=False, - access_token=self.tok, - ) - assert channel2.code == 200 - - # 17th will hit ratelimit - channel3 = self.make_request( - "GET", - "/_matrix/client/v1/media/download/remote.org/abcd", - shorthand=False, - access_token=self.tok, - ) - assert channel3.code == 429 - - # however, a request from a different IP will go through - channel4 = self.make_request( - "GET", - "/_matrix/client/v1/media/download/remote.org/abcde", - shorthand=False, - client_ip="187.233.230.159", - access_token=self.tok, - ) - assert channel4.code == 200 - - # at 87Kib/s it should take about 2 minutes for enough to drain from bucket that another - # 30MiB download is authorized - The last download was blocked at 503,316,480. 
- # The next download will be authorized when bucket hits 492,830,720 - # (524,288,000 total capacity - 31,457,280 download size) so 503,316,480 - 492,830,720 ~= 10,485,760 - # needs to drain before another download will be authorized, that will take ~= - # 2 minutes (10,485,760/89,088/60) - self.reactor.pump([2.0 * 60.0]) - - # enough has drained and next request goes through - channel5 = self.make_request( - "GET", - "/_matrix/client/v1/media/download/remote.org/abcdef", - shorthand=False, - access_token=self.tok, - ) - assert channel5.code == 200 - - @override_config( - { - "remote_media_download_per_second": "50M", - "remote_media_download_burst_count": "50M", - } - ) - @patch( - "synapse.http.matrixfederationclient.read_multipart_response", - read_multipart_response_50MiB, - ) - def test_download_rate_limit_config(self) -> None: - """ - Test that download rate limit config options are correctly picked up and applied - """ - - async def _send_request(*args: Any, **kwargs: Any) -> IResponse: - resp = MagicMock(spec=IResponse) - resp.code = 200 - resp.length = 52428800 - resp.headers = Headers( - {"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]} - ) - resp.phrase = b"OK" - return resp - - self.client._send_request = _send_request # type: ignore - - # first request should go through - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/download/remote.org/abc", - shorthand=False, - access_token=self.tok, - ) - assert channel.code == 200 - - # immediate second request should fail - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/download/remote.org/abcd", - shorthand=False, - access_token=self.tok, - ) - assert channel.code == 429 - - # advance half a second - self.reactor.pump([0.5]) - - # request still fails - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/download/remote.org/abcde", - shorthand=False, - access_token=self.tok, - ) - assert channel.code == 429 - - # advance another half second 
- self.reactor.pump([0.5]) - - # enough has drained from bucket and request is successful - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/download/remote.org/abcdef", - shorthand=False, - access_token=self.tok, - ) - assert channel.code == 200 - - @override_config( - { - "remote_media_download_burst_count": "87M", - } - ) - @patch( - "synapse.http.matrixfederationclient.read_multipart_response", - read_multipart_response_30MiB, - ) - def test_download_ratelimit_unknown_length(self) -> None: - """ - Test that if no content-length is provided, ratelimiting is still applied after - media is downloaded and length is known - """ - - # mock out actually sending the request - async def _send_request(*args: Any, **kwargs: Any) -> IResponse: - resp = MagicMock(spec=IResponse) - resp.code = 200 - resp.length = UNKNOWN_LENGTH - resp.headers = Headers( - {"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]} - ) - resp.phrase = b"OK" - return resp - - self.client._send_request = _send_request # type: ignore - - # first 3 will go through (note that 3rd request technically violates rate limit but - # that since the ratelimiting is applied *after* download it goes through, but next one fails) - for i in range(3): - channel2 = self.make_request( - "GET", - f"/_matrix/client/v1/media/download/remote.org/abc{i}", - shorthand=False, - access_token=self.tok, - ) - assert channel2.code == 200 - - # 4th will hit ratelimit - channel3 = self.make_request( - "GET", - "/_matrix/client/v1/media/download/remote.org/abcd", - shorthand=False, - access_token=self.tok, - ) - assert channel3.code == 429 - - @override_config({"max_upload_size": "29M"}) - @patch( - "synapse.http.matrixfederationclient.read_multipart_response", - read_multipart_response_30MiB, - ) - def test_max_download_respected(self) -> None: - """ - Test that the max download size is enforced - note that max download size is determined - by the max_upload_size - """ - - # mock out actually 
sending the request, returns a 30MiB response - async def _send_request(*args: Any, **kwargs: Any) -> IResponse: - resp = MagicMock(spec=IResponse) - resp.code = 200 - resp.length = 31457280 - resp.headers = Headers( - {"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]} - ) - resp.phrase = b"OK" - return resp - - self.client._send_request = _send_request # type: ignore - - channel = self.make_request( - "GET", - "/_matrix/client/v1/media/download/remote.org/abcd", - shorthand=False, - access_token=self.tok, - ) - assert channel.code == 502 - assert channel.json_body["errcode"] == "M_TOO_LARGE" - - def test_file_download(self) -> None: - content = io.BytesIO(b"file_to_stream") - content_uri = self.get_success( - self.repo.create_content( - "text/plain", - "test_upload", - content, - 46, - UserID.from_string("@user_id:whatever.org"), - ) - ) - # test with a text file - channel = self.make_request( - "GET", - f"/_matrix/client/v1/media/download/test/{content_uri.media_id}", - shorthand=False, - access_token=self.tok, - ) - self.pump() - self.assertEqual(200, channel.code) - - -test_images = [ - small_png, - small_png_with_transparency, - small_lossless_webp, - empty_file, - SVG, -] -input_values = [(x,) for x in test_images] - - -@parameterized_class(("test_image",), input_values) -class DownloadAndThumbnailTestCase(unittest.HomeserverTestCase): - test_image: ClassVar[TestImage] - servlets = [ - media.register_servlets, - login.register_servlets, - admin.register_servlets, - ] - - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.fetches: List[ - Tuple[ - "Deferred[Any]", - str, - str, - Optional[QueryParams], - ] - ] = [] - - def federation_get_file( - destination: str, - path: str, - output_stream: BinaryIO, - download_ratelimiter: Ratelimiter, - ip_address: Any, - max_size: int, - args: Optional[QueryParams] = None, - retry_on_dns_fail: bool = True, - ignore_backoff: bool = False, - follow_redirects: bool = 
False, - ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]], bytes]]": - """A mock for MatrixFederationHttpClient.federation_get_file.""" - - def write_to( - r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]] - ) -> Tuple[int, Dict[bytes, List[bytes]], bytes]: - data, response = r - output_stream.write(data) - return response - - def write_err(f: Failure) -> Failure: - f.trap(HttpResponseException) - output_stream.write(f.value.response) - return f - - d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]]] = ( - Deferred() - ) - self.fetches.append((d, destination, path, args)) - # Note that this callback changes the value held by d. - d_after_callback = d.addCallbacks(write_to, write_err) - return make_deferred_yieldable(d_after_callback) - - def get_file( - destination: str, - path: str, - output_stream: BinaryIO, - download_ratelimiter: Ratelimiter, - ip_address: Any, - max_size: int, - args: Optional[QueryParams] = None, - retry_on_dns_fail: bool = True, - ignore_backoff: bool = False, - follow_redirects: bool = False, - ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]": - """A mock for MatrixFederationHttpClient.get_file.""" - - def write_to( - r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]] - ) -> Tuple[int, Dict[bytes, List[bytes]]]: - data, response = r - output_stream.write(data) - return response - - def write_err(f: Failure) -> Failure: - f.trap(HttpResponseException) - output_stream.write(f.value.response) - return f - - d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred() - self.fetches.append((d, destination, path, args)) - # Note that this callback changes the value held by d. 
- d_after_callback = d.addCallbacks(write_to, write_err) - return make_deferred_yieldable(d_after_callback) - - # Mock out the homeserver's MatrixFederationHttpClient - client = Mock() - client.federation_get_file = federation_get_file - client.get_file = get_file - - self.storage_path = self.mktemp() - self.media_store_path = self.mktemp() - os.mkdir(self.storage_path) - os.mkdir(self.media_store_path) - - config = self.default_config() - config["media_store_path"] = self.media_store_path - config["max_image_pixels"] = 2000000 - - provider_config = { - "module": "synapse.media.storage_provider.FileStorageProviderBackend", - "store_local": True, - "store_synchronous": False, - "store_remote": True, - "config": {"directory": self.storage_path}, - } - config["media_storage_providers"] = [provider_config] - - hs = self.setup_test_homeserver(config=config, federation_http_client=client) - - return hs - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastores().main - self.media_repo = hs.get_media_repository() - - self.remote = "example.com" - self.media_id = "12345" - - self.user = self.register_user("user", "pass") - self.tok = self.login("user", "pass") - - def _req( - self, content_disposition: Optional[bytes], include_content_type: bool = True - ) -> FakeChannel: - channel = self.make_request( - "GET", - f"/_matrix/client/v1/media/download/{self.remote}/{self.media_id}", - shorthand=False, - await_result=False, - access_token=self.tok, - ) - self.pump() - - # We've made one fetch, to example.com, using the federation media URL - self.assertEqual(len(self.fetches), 1) - self.assertEqual(self.fetches[0][1], "example.com") - self.assertEqual( - self.fetches[0][2], "/_matrix/federation/v1/media/download/" + self.media_id - ) - self.assertEqual( - self.fetches[0][3], - {"timeout_ms": "20000"}, - ) - - headers = { - b"Content-Length": [b"%d" % (len(self.test_image.data))], - } - - if include_content_type: - 
headers[b"Content-Type"] = [self.test_image.content_type] - - if content_disposition: - headers[b"Content-Disposition"] = [content_disposition] - - self.fetches[0][0].callback( - (self.test_image.data, (len(self.test_image.data), headers, b"{}")) - ) - - self.pump() - self.assertEqual(channel.code, 200) - - return channel - - def test_handle_missing_content_type(self) -> None: - channel = self._req( - b"attachment; filename=out" + self.test_image.extension, - include_content_type=False, - ) - headers = channel.headers - self.assertEqual(channel.code, 200) - self.assertEqual( - headers.getRawHeaders(b"Content-Type"), [b"application/octet-stream"] - ) - - def test_disposition_filename_ascii(self) -> None: - """ - If the filename is filename=<ascii> then Synapse will decode it as an - ASCII string, and use filename= in the response. - """ - channel = self._req(b"attachment; filename=out" + self.test_image.extension) - - headers = channel.headers - self.assertEqual( - headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type] - ) - self.assertEqual( - headers.getRawHeaders(b"Content-Disposition"), - [ - (b"inline" if self.test_image.is_inline else b"attachment") - + b"; filename=out" - + self.test_image.extension - ], - ) - - def test_disposition_filenamestar_utf8escaped(self) -> None: - """ - If the filename is filename=*utf8''<utf8 escaped> then Synapse will - correctly decode it as the UTF-8 string, and use filename* in the - response. 
- """ - filename = parse.quote("\u2603".encode()).encode("ascii") - channel = self._req( - b"attachment; filename*=utf-8''" + filename + self.test_image.extension - ) - - headers = channel.headers - self.assertEqual( - headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type] - ) - self.assertEqual( - headers.getRawHeaders(b"Content-Disposition"), - [ - (b"inline" if self.test_image.is_inline else b"attachment") - + b"; filename*=utf-8''" - + filename - + self.test_image.extension - ], - ) - - def test_disposition_none(self) -> None: - """ - If there is no filename, Content-Disposition should only - be a disposition type. - """ - channel = self._req(None) - - headers = channel.headers - self.assertEqual( - headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type] - ) - self.assertEqual( - headers.getRawHeaders(b"Content-Disposition"), - [b"inline" if self.test_image.is_inline else b"attachment"], - ) - - def test_x_robots_tag_header(self) -> None: - """ - Tests that the `X-Robots-Tag` header is present, which informs web crawlers - to not index, archive, or follow links in media. - """ - channel = self._req(b"attachment; filename=out" + self.test_image.extension) - - headers = channel.headers - self.assertEqual( - headers.getRawHeaders(b"X-Robots-Tag"), - [b"noindex, nofollow, noarchive, noimageindex"], - ) - - def test_cross_origin_resource_policy_header(self) -> None: - """ - Test that the Cross-Origin-Resource-Policy header is set to "cross-origin" - allowing web clients to embed media from the downloads API. 
- """ - channel = self._req(b"attachment; filename=out" + self.test_image.extension) - - headers = channel.headers - - self.assertEqual( - headers.getRawHeaders(b"Cross-Origin-Resource-Policy"), - [b"cross-origin"], - ) - - def test_unknown_federation_endpoint(self) -> None: - """ - Test that if the download request to remote federation endpoint returns a 404 - we fall back to the _matrix/media endpoint - """ - channel = self.make_request( - "GET", - f"/_matrix/client/v1/media/download/{self.remote}/{self.media_id}", - shorthand=False, - await_result=False, - access_token=self.tok, - ) - self.pump() - - # We've made one fetch, to example.com, using the media URL, and asking - # the other server not to do a remote fetch - self.assertEqual(len(self.fetches), 1) - self.assertEqual(self.fetches[0][1], "example.com") - self.assertEqual( - self.fetches[0][2], f"/_matrix/federation/v1/media/download/{self.media_id}" - ) - - # The result which says the endpoint is unknown. - unknown_endpoint = b'{"errcode":"M_UNRECOGNIZED","error":"Unknown request"}' - self.fetches[0][0].errback( - HttpResponseException(404, "NOT FOUND", unknown_endpoint) - ) - - self.pump() - - # There should now be another request to the _matrix/media/v3/download URL. 
- self.assertEqual(len(self.fetches), 2) - self.assertEqual(self.fetches[1][1], "example.com") - self.assertEqual( - self.fetches[1][2], - f"/_matrix/media/v3/download/example.com/{self.media_id}", - ) - - headers = { - b"Content-Length": [b"%d" % (len(self.test_image.data))], - } - - self.fetches[1][0].callback( - (self.test_image.data, (len(self.test_image.data), headers)) - ) - - self.pump() - self.assertEqual(channel.code, 200) - - def test_thumbnail_crop(self) -> None: - """Test that a cropped remote thumbnail is available.""" - self._test_thumbnail( - "crop", - self.test_image.expected_cropped, - expected_found=self.test_image.expected_found, - unable_to_thumbnail=self.test_image.unable_to_thumbnail, - ) - - def test_thumbnail_scale(self) -> None: - """Test that a scaled remote thumbnail is available.""" - self._test_thumbnail( - "scale", - self.test_image.expected_scaled, - expected_found=self.test_image.expected_found, - unable_to_thumbnail=self.test_image.unable_to_thumbnail, - ) - - def test_invalid_type(self) -> None: - """An invalid thumbnail type is never available.""" - self._test_thumbnail( - "invalid", - None, - expected_found=False, - unable_to_thumbnail=self.test_image.unable_to_thumbnail, - ) - - @unittest.override_config( - {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]} - ) - def test_no_thumbnail_crop(self) -> None: - """ - Override the config to generate only scaled thumbnails, but request a cropped one. - """ - self._test_thumbnail( - "crop", - None, - expected_found=False, - unable_to_thumbnail=self.test_image.unable_to_thumbnail, - ) - - @unittest.override_config( - {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]} - ) - def test_no_thumbnail_scale(self) -> None: - """ - Override the config to generate only cropped thumbnails, but request a scaled one. 
- """ - self._test_thumbnail( - "scale", - None, - expected_found=False, - unable_to_thumbnail=self.test_image.unable_to_thumbnail, - ) - - def test_thumbnail_repeated_thumbnail(self) -> None: - """Test that fetching the same thumbnail works, and deleting the on disk - thumbnail regenerates it. - """ - self._test_thumbnail( - "scale", - self.test_image.expected_scaled, - expected_found=self.test_image.expected_found, - unable_to_thumbnail=self.test_image.unable_to_thumbnail, - ) - - if not self.test_image.expected_found: - return - - # Fetching again should work, without re-requesting the image from the - # remote. - params = "?width=32&height=32&method=scale" - channel = self.make_request( - "GET", - f"/_matrix/client/v1/media/thumbnail/{self.remote}/{self.media_id}{params}", - shorthand=False, - await_result=False, - access_token=self.tok, - ) - self.pump() - - self.assertEqual(channel.code, 200) - if self.test_image.expected_scaled: - self.assertEqual( - channel.result["body"], - self.test_image.expected_scaled, - channel.result["body"], - ) - - # Deleting the thumbnail on disk then re-requesting it should work as - # Synapse should regenerate missing thumbnails. 
- info = self.get_success( - self.store.get_cached_remote_media(self.remote, self.media_id) - ) - assert info is not None - file_id = info.filesystem_id - - thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir( - self.remote, file_id - ) - shutil.rmtree(thumbnail_dir, ignore_errors=True) - - channel = self.make_request( - "GET", - f"/_matrix/client/v1/media/thumbnail/{self.remote}/{self.media_id}{params}", - shorthand=False, - await_result=False, - access_token=self.tok, - ) - self.pump() - - self.assertEqual(channel.code, 200) - if self.test_image.expected_scaled: - self.assertEqual( - channel.result["body"], - self.test_image.expected_scaled, - channel.result["body"], - ) - - def _test_thumbnail( - self, - method: str, - expected_body: Optional[bytes], - expected_found: bool, - unable_to_thumbnail: bool = False, - ) -> None: - """Test the given thumbnailing method works as expected. - - Args: - method: The thumbnailing method to use (crop, scale). - expected_body: The expected bytes from thumbnailing, or None if - test should just check for a valid image. - expected_found: True if the file should exist on the server, or False if - a 404/400 is expected. - unable_to_thumbnail: True if we expect the thumbnailing to fail (400), or - False if the thumbnailing should succeed or a normal 404 is expected. 
- """ - - params = "?width=32&height=32&method=" + method - channel = self.make_request( - "GET", - f"/_matrix/client/v1/media/thumbnail/{self.remote}/{self.media_id}{params}", - shorthand=False, - await_result=False, - access_token=self.tok, - ) - self.pump() - headers = { - b"Content-Length": [b"%d" % (len(self.test_image.data))], - b"Content-Type": [self.test_image.content_type], - } - self.fetches[0][0].callback( - (self.test_image.data, (len(self.test_image.data), headers)) - ) - self.pump() - if expected_found: - self.assertEqual(channel.code, 200) - - self.assertEqual( - channel.headers.getRawHeaders(b"Cross-Origin-Resource-Policy"), - [b"cross-origin"], - ) - - if expected_body is not None: - self.assertEqual( - channel.result["body"], expected_body, channel.result["body"] - ) - else: - # ensure that the result is at least some valid image - Image.open(io.BytesIO(channel.result["body"])) - elif unable_to_thumbnail: - # A 400 with a JSON body. - self.assertEqual(channel.code, 400) - self.assertEqual( - channel.json_body, - { - "errcode": "M_UNKNOWN", - "error": "Cannot find any thumbnails for the requested media ('/_matrix/client/v1/media/thumbnail/example.com/12345'). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)", - }, - ) - else: - # A 404 with a JSON body. - self.assertEqual(channel.code, 404) - self.assertEqual( - channel.json_body, - { - "errcode": "M_NOT_FOUND", - "error": "Not found '/_matrix/client/v1/media/thumbnail/example.com/12345'", - }, - ) - - @parameterized.expand([("crop", 16), ("crop", 64), ("scale", 16), ("scale", 64)]) - def test_same_quality(self, method: str, desired_size: int) -> None: - """Test that choosing between thumbnails with the same quality rating succeeds. 
- - We are not particular about which thumbnail is chosen.""" - - content_type = self.test_image.content_type.decode() - media_repo = self.hs.get_media_repository() - thumbnail_provider = ThumbnailProvider( - self.hs, media_repo, media_repo.media_storage - ) - - self.assertIsNotNone( - thumbnail_provider._select_thumbnail( - desired_width=desired_size, - desired_height=desired_size, - desired_method=method, - desired_type=content_type, - # Provide two identical thumbnails which are guaranteed to have the same - # quality rating. - thumbnail_infos=[ - ThumbnailInfo( - width=32, - height=32, - method=method, - type=content_type, - length=256, - ), - ThumbnailInfo( - width=32, - height=32, - method=method, - type=content_type, - length=256, - ), - ], - file_id=f"image{self.test_image.extension.decode()}", - url_cache=False, - server_name=None, - ) - ) - - -configs = [ - {"extra_config": {"dynamic_thumbnails": True}}, - {"extra_config": {"dynamic_thumbnails": False}}, -] - - -@parameterized_class(configs) -class AuthenticatedMediaTestCase(unittest.HomeserverTestCase): - extra_config: Dict[str, Any] - servlets = [ - media.register_servlets, - login.register_servlets, - admin.register_servlets, - ] - - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() - - self.clock = clock - self.storage_path = self.mktemp() - self.media_store_path = self.mktemp() - os.mkdir(self.storage_path) - os.mkdir(self.media_store_path) - config["media_store_path"] = self.media_store_path - config["enable_authenticated_media"] = True - - provider_config = { - "module": "synapse.media.storage_provider.FileStorageProviderBackend", - "store_local": True, - "store_synchronous": False, - "store_remote": True, - "config": {"directory": self.storage_path}, - } - - config["media_storage_providers"] = [provider_config] - config.update(self.extra_config) - - return self.setup_test_homeserver(config=config) - - def prepare(self, reactor: 
MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.repo = hs.get_media_repository() - self.client = hs.get_federation_http_client() - self.store = hs.get_datastores().main - self.user = self.register_user("user", "pass") - self.tok = self.login("user", "pass") - - def create_resource_dict(self) -> Dict[str, Resource]: - resources = super().create_resource_dict() - resources["/_matrix/media"] = self.hs.get_media_repository_resource() - return resources - - def test_authenticated_media(self) -> None: - # upload some local media with authentication on - channel = self.make_request( - "POST", - "_matrix/media/v3/upload?filename=test_png_upload", - SMALL_PNG, - self.tok, - shorthand=False, - content_type=b"image/png", - custom_headers=[("Content-Length", str(67))], - ) - self.assertEqual(channel.code, 200) - res = channel.json_body.get("content_uri") - assert res is not None - uri = res.split("mxc://")[1] - - # request media over authenticated endpoint, should be found - channel2 = self.make_request( - "GET", - f"_matrix/client/v1/media/download/{uri}", - access_token=self.tok, - shorthand=False, - ) - self.assertEqual(channel2.code, 200) - - # request same media over unauthenticated media, should raise 404 not found - channel3 = self.make_request( - "GET", f"_matrix/media/v3/download/{uri}", shorthand=False - ) - self.assertEqual(channel3.code, 404) - - # check thumbnails as well - params = "?width=32&height=32&method=crop" - channel4 = self.make_request( - "GET", - f"/_matrix/client/v1/media/thumbnail/{uri}{params}", - shorthand=False, - access_token=self.tok, - ) - self.assertEqual(channel4.code, 200) - - params = "?width=32&height=32&method=crop" - channel5 = self.make_request( - "GET", - f"/_matrix/media/r0/thumbnail/{uri}{params}", - shorthand=False, - access_token=self.tok, - ) - self.assertEqual(channel5.code, 404) - - # Inject a piece of remote media. 
- file_id = "abcdefg12345" - file_info = FileInfo(server_name="lonelyIsland", file_id=file_id) - - media_storage = self.hs.get_media_repository().media_storage - - ctx = media_storage.store_into_file(file_info) - (f, fname) = self.get_success(ctx.__aenter__()) - f.write(SMALL_PNG) - self.get_success(ctx.__aexit__(None, None, None)) - - # we write the authenticated status when storing media, so this should pick up - # config and authenticate the media - self.get_success( - self.store.store_cached_remote_media( - origin="lonelyIsland", - media_id="52", - media_type="image/png", - media_length=1, - time_now_ms=self.clock.time_msec(), - upload_name="remote_test.png", - filesystem_id=file_id, - ) - ) - - # ensure we have thumbnails for the non-dynamic code path - if self.extra_config == {"dynamic_thumbnails": False}: - self.get_success( - self.repo._generate_thumbnails( - "lonelyIsland", "52", file_id, "image/png" - ) - ) - - channel6 = self.make_request( - "GET", - "_matrix/client/v1/media/download/lonelyIsland/52", - access_token=self.tok, - shorthand=False, - ) - self.assertEqual(channel6.code, 200) - - channel7 = self.make_request( - "GET", f"_matrix/media/v3/download/{uri}", shorthand=False - ) - self.assertEqual(channel7.code, 404) - - params = "?width=32&height=32&method=crop" - channel8 = self.make_request( - "GET", - f"/_matrix/client/v1/media/thumbnail/lonelyIsland/52{params}", - shorthand=False, - access_token=self.tok, - ) - self.assertEqual(channel8.code, 200) - - channel9 = self.make_request( - "GET", - f"/_matrix/media/r0/thumbnail/lonelyIsland/52{params}", - shorthand=False, - access_token=self.tok, - ) - self.assertEqual(channel9.code, 404) - - # Inject a piece of local media that isn't authenticated - file_id = "abcdefg123456" - file_info = FileInfo(None, file_id=file_id) - - ctx = media_storage.store_into_file(file_info) - (f, fname) = self.get_success(ctx.__aenter__()) - f.write(SMALL_PNG) - self.get_success(ctx.__aexit__(None, None, None)) - - 
self.get_success( - self.store.db_pool.simple_insert( - "local_media_repository", - { - "media_id": "abcdefg123456", - "media_type": "image/png", - "created_ts": self.clock.time_msec(), - "upload_name": "test_local", - "media_length": 1, - "user_id": "someone", - "url_cache": None, - "authenticated": False, - }, - desc="store_local_media", - ) - ) - - # check that unauthenticated media is still available over both endpoints - channel9 = self.make_request( - "GET", - "/_matrix/client/v1/media/download/test/abcdefg123456", - shorthand=False, - access_token=self.tok, - ) - self.assertEqual(channel9.code, 200) - - channel10 = self.make_request( - "GET", - "/_matrix/media/r0/download/test/abcdefg123456", - shorthand=False, - access_token=self.tok, - ) - self.assertEqual(channel10.code, 200) diff --git a/tests/rest/client/test_models.py b/tests/rest/client/test_models.py
index f8a56c80ca..99d0feb10d 100644 --- a/tests/rest/client/test_models.py +++ b/tests/rest/client/test_models.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -24,7 +23,7 @@ from typing import TYPE_CHECKING from typing_extensions import Literal from synapse._pydantic_compat import HAS_PYDANTIC_V2 -from synapse.types.rest.client import EmailRequestTokenBody +from synapse.rest.client.models import EmailRequestTokenBody if TYPE_CHECKING or HAS_PYDANTIC_V2: from pydantic.v1 import BaseModel, ValidationError diff --git a/tests/rest/client/test_mutual_rooms.py b/tests/rest/client/test_mutual_rooms.py
index 637722ca0a..e25f4e4175 100644 --- a/tests/rest/client/test_mutual_rooms.py +++ b/tests/rest/client/test_mutual_rooms.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 Half-Shot # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/rest/client/test_notifications.py b/tests/rest/client/test_notifications.py
index e4b0455ce8..8f1cbcd5db 100644 --- a/tests/rest/client/test_notifications.py +++ b/tests/rest/client/test_notifications.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -18,7 +17,6 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import List, Optional, Tuple from unittest.mock import AsyncMock, Mock from twisted.test.proto_helpers import MemoryReactor @@ -49,14 +47,6 @@ class HTTPPusherTests(HomeserverTestCase): self.sync_handler = homeserver.get_sync_handler() self.auth_handler = homeserver.get_auth_handler() - self.user_id = self.register_user("user", "pass") - self.access_token = self.login("user", "pass") - self.other_user_id = self.register_user("otheruser", "pass") - self.other_access_token = self.login("otheruser", "pass") - - # Create a room - self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token) - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # Mock out the calls over federation. 
fed_transport_client = Mock(spec=["send_transaction"]) @@ -70,22 +60,32 @@ class HTTPPusherTests(HomeserverTestCase): """ Local users will get notified for invites """ + + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + other_user_id = self.register_user("otheruser", "pass") + other_access_token = self.login("otheruser", "pass") + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + # Check we start with no pushes - self._request_notifications(from_token=None, limit=1, expected_count=0) + channel = self.make_request( + "GET", + "/notifications", + access_token=other_access_token, + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(len(channel.json_body["notifications"]), 0, channel.json_body) # Send an invite - self.helper.invite( - room=self.room_id, - src=self.user_id, - targ=self.other_user_id, - tok=self.access_token, - ) + self.helper.invite(room=room, src=user_id, targ=other_user_id, tok=access_token) # We should have a notification now channel = self.make_request( "GET", "/notifications", - access_token=self.other_access_token, + access_token=other_access_token, ) self.assertEqual(channel.code, 200) self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body) @@ -94,139 +94,3 @@ class HTTPPusherTests(HomeserverTestCase): "invite", channel.json_body, ) - - def test_pagination_of_notifications(self) -> None: - """ - Check that pagination of notifications works. - """ - # Check we start with no pushes - self._request_notifications(from_token=None, limit=1, expected_count=0) - - # Send an invite and have the other user join the room. - self.helper.invite( - room=self.room_id, - src=self.user_id, - targ=self.other_user_id, - tok=self.access_token, - ) - self.helper.join(self.room_id, self.other_user_id, tok=self.other_access_token) - - # Send 5 messages in the room and note down their event IDs. 
- sent_event_ids = [] - for _ in range(5): - resp = self.helper.send_event( - self.room_id, - "m.room.message", - {"body": "honk", "msgtype": "m.text"}, - tok=self.access_token, - ) - sent_event_ids.append(resp["event_id"]) - - # We expect to get notifications for messages in reverse order. - # So reverse this list of event IDs to make it easier to compare - # against later. - sent_event_ids.reverse() - - # We should have a few notifications now. Let's try and fetch the first 2. - notification_event_ids, _ = self._request_notifications( - from_token=None, limit=2, expected_count=2 - ) - - # Check we got the expected event IDs back. - self.assertEqual(notification_event_ids, sent_event_ids[:2]) - - # Try requesting again without a 'from' query parameter. We should get the - # same two notifications back. - notification_event_ids, next_token = self._request_notifications( - from_token=None, limit=2, expected_count=2 - ) - self.assertEqual(notification_event_ids, sent_event_ids[:2]) - - # Ask for the next 5 notifications, though there should only be - # 4 remaining; the next 3 messages and the invite. - # - # We need to use the "next_token" from the response as the "from" - # query parameter in the next request in order to paginate. - notification_event_ids, next_token = self._request_notifications( - from_token=next_token, limit=5, expected_count=4 - ) - # Ensure we chop off the invite on the end. - notification_event_ids = notification_event_ids[:-1] - self.assertEqual(notification_event_ids, sent_event_ids[2:]) - - def _request_notifications( - self, from_token: Optional[str], limit: int, expected_count: int - ) -> Tuple[List[str], str]: - """ - Make a request to /notifications to get the latest events to be notified about. - - Only the event IDs are returned. The request is made by the "other user". - - Args: - from_token: An optional starting parameter. - limit: The maximum number of results to return. 
- expected_count: The number of events to expect in the response. - - Returns: - A list of event IDs that the client should be notified about. - Events are returned newest-first. - """ - # Construct the request path. - path = f"/notifications?limit={limit}" - if from_token is not None: - path += f"&from={from_token}" - - channel = self.make_request( - "GET", - path, - access_token=self.other_access_token, - ) - - self.assertEqual(channel.code, 200) - self.assertEqual( - len(channel.json_body["notifications"]), expected_count, channel.json_body - ) - - # Extract the necessary data from the response. - next_token = channel.json_body["next_token"] - event_ids = [ - event["event"]["event_id"] for event in channel.json_body["notifications"] - ] - - return event_ids, next_token - - def test_parameters(self) -> None: - """ - Test that appropriate errors are returned when query parameters are malformed. - """ - # Test that no parameters are required. - channel = self.make_request( - "GET", - "/notifications", - access_token=self.other_access_token, - ) - self.assertEqual(channel.code, 200) - - # Test that limit cannot be negative - channel = self.make_request( - "GET", - "/notifications?limit=-1", - access_token=self.other_access_token, - ) - self.assertEqual(channel.code, 400) - - # Test that the 'limit' parameter must be an integer. - channel = self.make_request( - "GET", - "/notifications?limit=foobar", - access_token=self.other_access_token, - ) - self.assertEqual(channel.code, 400) - - # Test that the 'from' parameter must be an integer. - channel = self.make_request( - "GET", - "/notifications?from=osborne", - access_token=self.other_access_token, - ) - self.assertEqual(channel.code, 400) diff --git a/tests/rest/client/test_password_policy.py b/tests/rest/client/test_password_policy.py
index f0ef733f7b..08d1cf56f2 100644 --- a/tests/rest/client/test_password_policy.py +++ b/tests/rest/client/test_password_policy.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py
index 1584c2e96c..634cda9262 100644 --- a/tests/rest/client/test_power_levels.py +++ b/tests/rest/client/test_power_levels.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py
index f98f3f77aa..b9852928c0 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2014-2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/rest/client/test_push_rule_attrs.py b/tests/rest/client/test_push_rule_attrs.py
index 9da0e7982f..0ee66b6a84 100644 --- a/tests/rest/client/test_push_rule_attrs.py +++ b/tests/rest/client/test_push_rule_attrs.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/rest/client/test_read_marker.py b/tests/rest/client/test_read_marker.py
index 0b4ad685b3..ed0ac2af6e 100644 --- a/tests/rest/client/test_read_marker.py +++ b/tests/rest/client/test_read_marker.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2023 Beeper # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -78,7 +77,7 @@ class ReadMarkerTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", - f"/rooms/{room_id}/read_markers", + "/rooms/!abc:beep/read_markers", content={ "m.fully_read": event_id_1, }, @@ -90,7 +89,7 @@ class ReadMarkerTestCase(unittest.HomeserverTestCase): event_id_2 = send_message() channel = self.make_request( "POST", - f"/rooms/{room_id}/read_markers", + "/rooms/!abc:beep/read_markers", content={ "m.fully_read": event_id_2, }, @@ -123,7 +122,7 @@ class ReadMarkerTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", - f"/rooms/{room_id}/read_markers", + "/rooms/!abc:beep/read_markers", content={ "m.fully_read": event_id_1, }, @@ -142,7 +141,7 @@ class ReadMarkerTestCase(unittest.HomeserverTestCase): event_id_2 = send_message() channel = self.make_request( "POST", - f"/rooms/{room_id}/read_markers", + "/rooms/!abc:beep/read_markers", content={ "m.fully_read": event_id_2, }, diff --git a/tests/rest/client/test_receipts.py b/tests/rest/client/test_receipts.py
index f0648289f1..9fe49a12f6 100644 --- a/tests/rest/client/test_receipts.py +++ b/tests/rest/client/test_receipts.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index b25e184786..2afb3f065a 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 694f143eff..13f27d13c9 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py
@@ -1,8 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019 The Matrix.org Foundation C.I.C. -# Copyright 2014-2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -22,7 +20,6 @@ import datetime import os from typing import Any, Dict, List, Tuple -from unittest.mock import AsyncMock import pkg_resources @@ -43,7 +40,6 @@ from synapse.types import JsonDict from synapse.util import Clock from tests import unittest -from tests.server import ThreadedMemoryReactorClock from tests.unittest import override_config @@ -60,13 +56,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): config["allow_guest_access"] = True return config - def make_homeserver( - self, reactor: ThreadedMemoryReactorClock, clock: Clock - ) -> HomeServer: - hs = super().make_homeserver(reactor, clock) - hs.get_send_email_handler()._sendmail = AsyncMock() - return hs - def test_POST_appservice_registration_valid(self) -> None: user_id = "@as_user_kermit:test" as_token = "i_am_an_app_service" diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index f5a7602d0a..e0529d88c5 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -35,6 +34,7 @@ from synapse.util import Clock from tests import unittest from tests.server import FakeChannel from tests.test_utils.event_injection import inject_event +from tests.unittest import override_config class BaseRelationsTestCase(unittest.HomeserverTestCase): @@ -956,6 +956,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase): class RecursiveRelationTestCase(BaseRelationsTestCase): + @override_config({"experimental_features": {"msc3981_recurse_relations": True}}) def test_recursive_relations(self) -> None: """Generate a complex, multi-level relationship tree and query it.""" # Create a thread with a few messages in it. @@ -1001,7 +1002,7 @@ class RecursiveRelationTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}" - "?dir=f&limit=20&recurse=true", + "?dir=f&limit=20&org.matrix.msc3981.recurse=true", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) @@ -1022,6 +1023,7 @@ class RecursiveRelationTestCase(BaseRelationsTestCase): ], ) + @override_config({"experimental_features": {"msc3981_recurse_relations": True}}) def test_recursive_relations_with_filter(self) -> None: """The event_type and rel_type still apply.""" # Create a thread with a few messages in it. 
@@ -1049,7 +1051,7 @@ class RecursiveRelationTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}/{RelationTypes.ANNOTATION}" - "?dir=f&limit=20&recurse=true", + "?dir=f&limit=20&org.matrix.msc3981.recurse=true", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) @@ -1062,7 +1064,7 @@ class RecursiveRelationTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}/{RelationTypes.ANNOTATION}/m.reaction" - "?dir=f&limit=20&recurse=true", + "?dir=f&limit=20&org.matrix.msc3981.recurse=true", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py
index 0ab754a11a..07e45f14f9 100644 --- a/tests/rest/client/test_rendezvous.py +++ b/tests/rest/client/test_rendezvous.py
@@ -1,8 +1,7 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. -# Copyright (C) 2023-2024 New Vector, Ltd +# Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as @@ -19,23 +18,16 @@ # # -from typing import Dict -from urllib.parse import urlparse - from twisted.test.proto_helpers import MemoryReactor -from twisted.web.resource import Resource from synapse.rest.client import rendezvous -from synapse.rest.synapse.client.rendezvous import MSC4108RendezvousSessionResource from synapse.server import HomeServer from synapse.util import Clock from tests import unittest from tests.unittest import override_config -from tests.utils import HAS_AUTHLIB -msc3886_endpoint = "/_matrix/client/unstable/org.matrix.msc3886/rendezvous" -msc4108_endpoint = "/_matrix/client/unstable/org.matrix.msc4108/rendezvous" +endpoint = "/_matrix/client/unstable/org.matrix.msc3886/rendezvous" class RendezvousServletTestCase(unittest.HomeserverTestCase): @@ -47,430 +39,12 @@ class RendezvousServletTestCase(unittest.HomeserverTestCase): self.hs = self.setup_test_homeserver() return self.hs - def create_resource_dict(self) -> Dict[str, Resource]: - return { - **super().create_resource_dict(), - "/_synapse/client/rendezvous": MSC4108RendezvousSessionResource(self.hs), - } - def test_disabled(self) -> None: - channel = self.make_request("POST", msc3886_endpoint, {}, access_token=None) - self.assertEqual(channel.code, 404) - channel = self.make_request("POST", msc4108_endpoint, {}, access_token=None) + channel = self.make_request("POST", endpoint, {}, access_token=None) self.assertEqual(channel.code, 404) @override_config({"experimental_features": {"msc3886_endpoint": "/asd"}}) - def test_msc3886_redirect(self) -> None: - channel = self.make_request("POST", msc3886_endpoint, {}, access_token=None) + 
def test_redirect(self) -> None: + channel = self.make_request("POST", endpoint, {}, access_token=None) self.assertEqual(channel.code, 307) self.assertEqual(channel.headers.getRawHeaders("Location"), ["/asd"]) - - @unittest.skip_unless(HAS_AUTHLIB, "requires authlib") - @override_config( - { - "disable_registration": True, - "experimental_features": { - "msc4108_delegation_endpoint": "https://asd", - "msc3861": { - "enabled": True, - "issuer": "https://issuer", - "client_id": "client_id", - "client_auth_method": "client_secret_post", - "client_secret": "client_secret", - "admin_token": "admin_token_value", - }, - }, - } - ) - def test_msc4108_delegation(self) -> None: - channel = self.make_request("POST", msc4108_endpoint, {}, access_token=None) - self.assertEqual(channel.code, 307) - self.assertEqual(channel.headers.getRawHeaders("Location"), ["https://asd"]) - - @unittest.skip_unless(HAS_AUTHLIB, "requires authlib") - @override_config( - { - "disable_registration": True, - "experimental_features": { - "msc4108_enabled": True, - "msc3861": { - "enabled": True, - "issuer": "https://issuer", - "client_id": "client_id", - "client_auth_method": "client_secret_post", - "client_secret": "client_secret", - "admin_token": "admin_token_value", - }, - }, - } - ) - def test_msc4108(self) -> None: - """ - Test the MSC4108 rendezvous endpoint, including: - - Creating a session - - Getting the data back - - Updating the data - - Deleting the data - - ETag handling - """ - # We can post arbitrary data to the endpoint - channel = self.make_request( - "POST", - msc4108_endpoint, - "foo=bar", - content_type=b"text/plain", - access_token=None, - ) - self.assertEqual(channel.code, 201) - self.assertSubstring("/_synapse/client/rendezvous/", channel.json_body["url"]) - headers = dict(channel.headers.getAllRawHeaders()) - self.assertIn(b"ETag", headers) - self.assertIn(b"Expires", headers) - self.assertEqual(headers[b"Content-Type"], [b"application/json"]) - 
self.assertEqual(headers[b"Access-Control-Allow-Origin"], [b"*"]) - self.assertEqual(headers[b"Access-Control-Expose-Headers"], [b"etag"]) - self.assertEqual(headers[b"Cache-Control"], [b"no-store"]) - self.assertEqual(headers[b"Pragma"], [b"no-cache"]) - self.assertIn("url", channel.json_body) - self.assertTrue(channel.json_body["url"].startswith("https://")) - - url = urlparse(channel.json_body["url"]) - session_endpoint = url.path - etag = headers[b"ETag"][0] - - # We can get the data back - channel = self.make_request( - "GET", - session_endpoint, - access_token=None, - ) - - self.assertEqual(channel.code, 200) - headers = dict(channel.headers.getAllRawHeaders()) - self.assertEqual(headers[b"ETag"], [etag]) - self.assertIn(b"Expires", headers) - self.assertEqual(headers[b"Content-Type"], [b"text/plain"]) - self.assertEqual(headers[b"Access-Control-Allow-Origin"], [b"*"]) - self.assertEqual(headers[b"Access-Control-Expose-Headers"], [b"etag"]) - self.assertEqual(headers[b"Cache-Control"], [b"no-store"]) - self.assertEqual(headers[b"Pragma"], [b"no-cache"]) - self.assertEqual(channel.text_body, "foo=bar") - - # We can make sure the data hasn't changed - channel = self.make_request( - "GET", - session_endpoint, - access_token=None, - custom_headers=[("If-None-Match", etag)], - ) - - self.assertEqual(channel.code, 304) - - # We can update the data - channel = self.make_request( - "PUT", - session_endpoint, - "foo=baz", - content_type=b"text/plain", - access_token=None, - custom_headers=[("If-Match", etag)], - ) - - self.assertEqual(channel.code, 202) - headers = dict(channel.headers.getAllRawHeaders()) - old_etag = etag - new_etag = headers[b"ETag"][0] - - # If we try to update it again with the old etag, it should fail - channel = self.make_request( - "PUT", - session_endpoint, - "bar=baz", - content_type=b"text/plain", - access_token=None, - custom_headers=[("If-Match", old_etag)], - ) - - self.assertEqual(channel.code, 412) - 
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN") - self.assertEqual( - channel.json_body["org.matrix.msc4108.errcode"], "M_CONCURRENT_WRITE" - ) - - # If we try to get with the old etag, we should get the updated data - channel = self.make_request( - "GET", - session_endpoint, - access_token=None, - custom_headers=[("If-None-Match", old_etag)], - ) - - self.assertEqual(channel.code, 200) - headers = dict(channel.headers.getAllRawHeaders()) - self.assertEqual(headers[b"ETag"], [new_etag]) - self.assertEqual(channel.text_body, "foo=baz") - - # We can delete the data - channel = self.make_request( - "DELETE", - session_endpoint, - access_token=None, - ) - - self.assertEqual(channel.code, 204) - - # If we try to get the data again, it should fail - channel = self.make_request( - "GET", - session_endpoint, - access_token=None, - ) - - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") - - @unittest.skip_unless(HAS_AUTHLIB, "requires authlib") - @override_config( - { - "disable_registration": True, - "experimental_features": { - "msc4108_enabled": True, - "msc3861": { - "enabled": True, - "issuer": "https://issuer", - "client_id": "client_id", - "client_auth_method": "client_secret_post", - "client_secret": "client_secret", - "admin_token": "admin_token_value", - }, - }, - } - ) - def test_msc4108_expiration(self) -> None: - """ - Test that entries are evicted after a TTL. 
- """ - # Start a new session - channel = self.make_request( - "POST", - msc4108_endpoint, - "foo=bar", - content_type=b"text/plain", - access_token=None, - ) - self.assertEqual(channel.code, 201) - session_endpoint = urlparse(channel.json_body["url"]).path - - # Sanity check that we can get the data back - channel = self.make_request( - "GET", - session_endpoint, - access_token=None, - ) - self.assertEqual(channel.code, 200) - self.assertEqual(channel.text_body, "foo=bar") - - # Advance the clock, TTL of entries is 1 minute - self.reactor.advance(60) - - # Get the data back, it should be gone - channel = self.make_request( - "GET", - session_endpoint, - access_token=None, - ) - self.assertEqual(channel.code, 404) - - @unittest.skip_unless(HAS_AUTHLIB, "requires authlib") - @override_config( - { - "disable_registration": True, - "experimental_features": { - "msc4108_enabled": True, - "msc3861": { - "enabled": True, - "issuer": "https://issuer", - "client_id": "client_id", - "client_auth_method": "client_secret_post", - "client_secret": "client_secret", - "admin_token": "admin_token_value", - }, - }, - } - ) - def test_msc4108_capacity(self) -> None: - """ - Test that a capacity limit is enforced on the rendezvous sessions, as old - entries are evicted at an interval when the limit is reached. 
- """ - # Start a new session - channel = self.make_request( - "POST", - msc4108_endpoint, - "foo=bar", - content_type=b"text/plain", - access_token=None, - ) - self.assertEqual(channel.code, 201) - session_endpoint = urlparse(channel.json_body["url"]).path - - # Sanity check that we can get the data back - channel = self.make_request( - "GET", - session_endpoint, - access_token=None, - ) - self.assertEqual(channel.code, 200) - self.assertEqual(channel.text_body, "foo=bar") - - # Start a lot of new sessions - for _ in range(100): - channel = self.make_request( - "POST", - msc4108_endpoint, - "foo=bar", - content_type=b"text/plain", - access_token=None, - ) - self.assertEqual(channel.code, 201) - - # Get the data back, it should still be there, as the eviction hasn't run yet - channel = self.make_request( - "GET", - session_endpoint, - access_token=None, - ) - - self.assertEqual(channel.code, 200) - - # Advance the clock, as it will trigger the eviction - self.reactor.advance(1) - - # Get the data back, it should be gone - channel = self.make_request( - "GET", - session_endpoint, - access_token=None, - ) - - @unittest.skip_unless(HAS_AUTHLIB, "requires authlib") - @override_config( - { - "disable_registration": True, - "experimental_features": { - "msc4108_enabled": True, - "msc3861": { - "enabled": True, - "issuer": "https://issuer", - "client_id": "client_id", - "client_auth_method": "client_secret_post", - "client_secret": "client_secret", - "admin_token": "admin_token_value", - }, - }, - } - ) - def test_msc4108_hard_capacity(self) -> None: - """ - Test that a hard capacity limit is enforced on the rendezvous sessions, as old - entries are evicted immediately when the limit is reached. 
- """ - # Start a new session - channel = self.make_request( - "POST", - msc4108_endpoint, - "foo=bar", - content_type=b"text/plain", - access_token=None, - ) - self.assertEqual(channel.code, 201) - session_endpoint = urlparse(channel.json_body["url"]).path - # We advance the clock to make sure that this entry is the "lowest" in the session list - self.reactor.advance(1) - - # Sanity check that we can get the data back - channel = self.make_request( - "GET", - session_endpoint, - access_token=None, - ) - self.assertEqual(channel.code, 200) - self.assertEqual(channel.text_body, "foo=bar") - - # Start a lot of new sessions - for _ in range(200): - channel = self.make_request( - "POST", - msc4108_endpoint, - "foo=bar", - content_type=b"text/plain", - access_token=None, - ) - self.assertEqual(channel.code, 201) - - # Get the data back, it should already be gone as we hit the hard limit - channel = self.make_request( - "GET", - session_endpoint, - access_token=None, - ) - - self.assertEqual(channel.code, 404) - - @unittest.skip_unless(HAS_AUTHLIB, "requires authlib") - @override_config( - { - "disable_registration": True, - "experimental_features": { - "msc4108_enabled": True, - "msc3861": { - "enabled": True, - "issuer": "https://issuer", - "client_id": "client_id", - "client_auth_method": "client_secret_post", - "client_secret": "client_secret", - "admin_token": "admin_token_value", - }, - }, - } - ) - def test_msc4108_content_type(self) -> None: - """ - Test that the content-type is restricted to text/plain. 
- """ - # We cannot post invalid content-type arbitrary data to the endpoint - channel = self.make_request( - "POST", - msc4108_endpoint, - "foo=bar", - content_is_form=True, - access_token=None, - ) - self.assertEqual(channel.code, 400) - self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM") - - # Make a valid request - channel = self.make_request( - "POST", - msc4108_endpoint, - "foo=bar", - content_type=b"text/plain", - access_token=None, - ) - self.assertEqual(channel.code, 201) - url = urlparse(channel.json_body["url"]) - session_endpoint = url.path - headers = dict(channel.headers.getAllRawHeaders()) - etag = headers[b"ETag"][0] - - # We can't update the data with invalid content-type - channel = self.make_request( - "PUT", - session_endpoint, - "foo=baz", - content_is_form=True, - access_token=None, - custom_headers=[("If-Match", etag)], - ) - self.assertEqual(channel.code, 400) - self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM") diff --git a/tests/rest/client/test_reporting.py b/tests/rest/client/test_report_event.py
index 009deb9cb0..41c03b6f68 100644 --- a/tests/rest/client/test_reporting.py +++ b/tests/rest/client/test_report_event.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 Callum Brown # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -22,7 +21,7 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin -from synapse.rest.client import login, reporting, room +from synapse.rest.client import login, report_event, room from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock @@ -35,7 +34,7 @@ class ReportEventTestCase(unittest.HomeserverTestCase): synapse.rest.admin.register_servlets, login.register_servlets, room.register_servlets, - reporting.register_servlets, + report_event.register_servlets, ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -139,92 +138,3 @@ class ReportEventTestCase(unittest.HomeserverTestCase): "POST", self.report_path, data, access_token=self.other_user_tok ) self.assertEqual(response_status, channel.code, msg=channel.result["body"]) - - -class ReportRoomTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - reporting.register_servlets, - ] - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.other_user = self.register_user("user", "pass") - self.other_user_tok = self.login("user", "pass") - - self.room_id = self.helper.create_room_as( - self.other_user, tok=self.other_user_tok, is_public=True - ) - self.report_path = ( - f"/_matrix/client/unstable/org.matrix.msc4151/rooms/{self.room_id}/report" - ) - - @unittest.override_config( - { - "experimental_features": {"msc4151_enabled": True}, - } - ) - def test_reason_str(self) -> None: - data = {"reason": "this makes me sad"} - self._assert_status(200, data) - - @unittest.override_config( - { - "experimental_features": {"msc4151_enabled": True}, - } - 
) - def test_no_reason(self) -> None: - data = {"not_reason": "for typechecking"} - self._assert_status(400, data) - - @unittest.override_config( - { - "experimental_features": {"msc4151_enabled": True}, - } - ) - def test_reason_nonstring(self) -> None: - data = {"reason": 42} - self._assert_status(400, data) - - @unittest.override_config( - { - "experimental_features": {"msc4151_enabled": True}, - } - ) - def test_reason_null(self) -> None: - data = {"reason": None} - self._assert_status(400, data) - - @unittest.override_config( - { - "experimental_features": {"msc4151_enabled": True}, - } - ) - def test_cannot_report_nonexistent_room(self) -> None: - """ - Tests that we don't accept event reports for rooms which do not exist. - """ - channel = self.make_request( - "POST", - "/_matrix/client/unstable/org.matrix.msc4151/rooms/!bloop:example.org/report", - {"reason": "i am very sad"}, - access_token=self.other_user_tok, - shorthand=False, - ) - self.assertEqual(404, channel.code, msg=channel.result["body"]) - self.assertEqual( - "Room does not exist", - channel.json_body["error"], - msg=channel.result["body"], - ) - - def _assert_status(self, response_status: int, data: JsonDict) -> None: - channel = self.make_request( - "POST", - self.report_path, - data, - access_token=self.other_user_tok, - shorthand=False, - ) - self.assertEqual(response_status, channel.code, msg=channel.result["body"]) diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 1e5a1b0a4d..09a5d64349 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py
@@ -163,11 +163,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): ) self.assertEqual(2, len(events), "events retrieved from database") filtered_events = self.get_success( - filter_events_for_client( - storage_controllers, - self.user_id, - events, - ) + filter_events_for_client(storage_controllers, self.user_id, events) ) # We should only get one event back. diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index c559dfda83..0e71cdcd88 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py
@@ -1,9 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019 The Matrix.org Foundation C.I.C. -# Copyright 2017 Vector Creations Ltd -# Copyright 2014-2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -48,16 +45,7 @@ from synapse.appservice import ApplicationService from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.rest import admin -from synapse.rest.client import ( - account, - directory, - knock, - login, - profile, - register, - room, - sync, -) +from synapse.rest.client import account, directory, login, profile, register, room, sync from synapse.server import HomeServer from synapse.types import JsonDict, RoomAlias, UserID, create_requester from synapse.util import Clock @@ -102,7 +90,6 @@ class RoomPermissionsTestCase(RoomBase): rmcreator_id = "@notme:red" def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store_controllers = hs.get_storage_controllers() self.helper.auth_user_id = self.rmcreator_id # create some rooms under the name rmcreator_id self.uncreated_rmid = "!aa:test" @@ -492,23 +479,6 @@ class RoomPermissionsTestCase(RoomBase): expect_code=HTTPStatus.OK, ) - def test_default_call_invite_power_level(self) -> None: - pl_event = self.get_success( - self.store_controllers.state.get_current_state_event( - self.created_public_rmid, EventTypes.PowerLevels, "" - ) - ) - assert pl_event is not None - self.assertEqual(50, pl_event.content.get("m.call.invite")) - - private_pl_event = self.get_success( - self.store_controllers.state.get_current_state_event( - self.created_rmid, EventTypes.PowerLevels, "" - ) - ) - assert private_pl_event is not None - self.assertEqual(None, private_pl_event.content.get("m.call.invite")) - class RoomStateTestCase(RoomBase): """Tests /rooms/$room_id/state.""" @@ -742,7 +712,7 @@ class 
RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(33, channel.resource_usage.db_txn_count) + self.assertEqual(32, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -755,7 +725,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(35, channel.resource_usage.db_txn_count) + self.assertEqual(34, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id @@ -1163,7 +1133,6 @@ class RoomJoinTestCase(RoomBase): admin.register_servlets, login.register_servlets, room.register_servlets, - knock.register_servlets, ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -1177,8 +1146,6 @@ class RoomJoinTestCase(RoomBase): self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) - self.store = hs.get_datastores().main - def test_spam_checker_may_join_room_deprecated(self) -> None: """Tests that the user_may_join_room spam checker callback is correctly called and blocks room joins when needed. @@ -1252,9 +1219,9 @@ class RoomJoinTestCase(RoomBase): """ # Register a dummy callback. Make it allow all room joins for now. 
- return_value: Union[Literal["NOT_SPAM"], Tuple[Codes, dict], Codes] = ( - synapse.module_api.NOT_SPAM - ) + return_value: Union[ + Literal["NOT_SPAM"], Tuple[Codes, dict], Codes + ] = synapse.module_api.NOT_SPAM async def user_may_join_room( userid: str, @@ -1329,57 +1296,6 @@ class RoomJoinTestCase(RoomBase): expect_additional_fields=return_value[1], ) - def test_suspended_user_cannot_join_room(self) -> None: - # set the user as suspended - self.get_success(self.store.set_user_suspended_status(self.user2, True)) - - channel = self.make_request( - "POST", f"/join/{self.room1}", access_token=self.tok2 - ) - self.assertEqual(channel.code, 403) - self.assertEqual( - channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" - ) - - channel = self.make_request( - "POST", f"/rooms/{self.room1}/join", access_token=self.tok2 - ) - self.assertEqual(channel.code, 403) - self.assertEqual( - channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" - ) - - def test_suspended_user_cannot_knock_on_room(self) -> None: - # set the user as suspended - self.get_success(self.store.set_user_suspended_status(self.user2, True)) - - channel = self.make_request( - "POST", - f"/_matrix/client/v3/knock/{self.room1}", - access_token=self.tok2, - content={}, - shorthand=False, - ) - self.assertEqual(channel.code, 403) - self.assertEqual( - channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" - ) - - def test_suspended_user_cannot_invite_to_room(self) -> None: - # set the user as suspended - self.get_success(self.store.set_user_suspended_status(self.user1, True)) - - # first user invites second user - channel = self.make_request( - "POST", - f"/rooms/{self.room1}/invite", - access_token=self.tok1, - content={"user_id": self.user2}, - ) - self.assertEqual( - channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" - ) - class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase): servlets = [ @@ -1745,9 +1661,9 @@ 
class RoomMessagesTestCase(RoomBase): expected_fields: dict, ) -> None: class SpamCheck: - mock_return_value: Union[str, bool, Codes, Tuple[Codes, JsonDict], bool] = ( - "NOT_SPAM" - ) + mock_return_value: Union[ + str, bool, Codes, Tuple[Codes, JsonDict], bool + ] = "NOT_SPAM" mock_content: Optional[JsonDict] = None async def check_event_for_spam( @@ -2238,58 +2154,6 @@ class RoomMessageListTestCase(RoomBase): chunk = channel.json_body["chunk"] self.assertEqual(len(chunk), 0, [event["content"] for event in chunk]) - def test_room_message_filter_query_validation(self) -> None: - # Test json validation in (filter) query parameter. - # Does not test the validity of the filter, only the json validation. - - # Check Get with valid json filter parameter, expect 200. - valid_filter_str = '{"types": ["m.room.message"]}' - channel = self.make_request( - "GET", - f"/rooms/{self.room_id}/messages?access_token=x&dir=b&filter={valid_filter_str}", - ) - - self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) - - # Check Get with invalid json filter parameter, expect 400 NOT_JSON. - invalid_filter_str = "}}}{}" - channel = self.make_request( - "GET", - f"/rooms/{self.room_id}/messages?access_token=x&dir=b&filter={invalid_filter_str}", - ) - - self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body) - self.assertEqual( - channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body - ) - - -class RoomMessageFilterTestCase(RoomBase): - """Tests /rooms/$room_id/messages REST events.""" - - user_id = "@sid1:red" - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.room_id = self.helper.create_room_as(self.user_id) - - def test_room_message_filter_wildcard(self) -> None: - # Send a first message in the room, which will be removed by the purge. 
- self.helper.send(self.room_id, "message 1", type="f.message.1") - self.helper.send(self.room_id, "message 1", type="f.message.2") - self.helper.send(self.room_id, "not returned in filter") - channel = self.make_request( - "GET", - "/rooms/%s/messages?access_token=x&dir=b&filter=%s" - % ( - self.room_id, - json.dumps({"types": ["f.message.*"]}), - ), - ) - self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) - - chunk = channel.json_body["chunk"] - self.assertEqual(len(chunk), 2, [event["content"] for event in chunk]) - class RoomSearchTestCase(unittest.HomeserverTestCase): servlets = [ @@ -3301,33 +3165,6 @@ class ContextTestCase(unittest.HomeserverTestCase): self.assertDictEqual(events_after[0].get("content"), {}, events_after[0]) self.assertEqual(events_after[1].get("content"), {}, events_after[1]) - def test_room_event_context_filter_query_validation(self) -> None: - # Test json validation in (filter) query parameter. - # Does not test the validity of the filter, only the json validation. - event_id = self.helper.send(self.room_id, "message 7", tok=self.tok)["event_id"] - - # Check Get with valid json filter parameter, expect 200. - valid_filter_str = '{"types": ["m.room.message"]}' - channel = self.make_request( - "GET", - f"/rooms/{self.room_id}/context/{event_id}?filter={valid_filter_str}", - access_token=self.tok, - ) - self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) - - # Check Get with invalid json filter parameter, expect 400 NOT_JSON. 
- invalid_filter_str = "}}}{}" - channel = self.make_request( - "GET", - f"/rooms/{self.room_id}/context/{event_id}?filter={invalid_filter_str}", - access_token=self.tok, - ) - - self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body) - self.assertEqual( - channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body - ) - class RoomAliasListTestCase(unittest.HomeserverTestCase): servlets = [ @@ -3819,108 +3656,3 @@ class TimestampLookupTestCase(unittest.HomeserverTestCase): # Make sure the outlier event is not returned self.assertNotEqual(channel.json_body["event_id"], outlier_event.event_id) - - -class UserSuspensionTests(unittest.HomeserverTestCase): - servlets = [ - admin.register_servlets, - login.register_servlets, - room.register_servlets, - profile.register_servlets, - ] - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.user1 = self.register_user("thomas", "hackme") - self.tok1 = self.login("thomas", "hackme") - - self.user2 = self.register_user("teresa", "hackme") - self.tok2 = self.login("teresa", "hackme") - - self.room1 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) - self.store = hs.get_datastores().main - - def test_suspended_user_cannot_send_message_to_room(self) -> None: - # set the user as suspended - self.get_success(self.store.set_user_suspended_status(self.user1, True)) - - channel = self.make_request( - "PUT", - f"/rooms/{self.room1}/send/m.room.message/1", - access_token=self.tok1, - content={"body": "hello", "msgtype": "m.text"}, - ) - self.assertEqual( - channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" - ) - - def test_suspended_user_cannot_change_profile_data(self) -> None: - # set the user as suspended - self.get_success(self.store.set_user_suspended_status(self.user1, True)) - - channel = self.make_request( - "PUT", - f"/_matrix/client/v3/profile/{self.user1}/avatar_url", - access_token=self.tok1, - content={"avatar_url": 
"mxc://matrix.org/wefh34uihSDRGhw34"}, - shorthand=False, - ) - self.assertEqual( - channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" - ) - - channel2 = self.make_request( - "PUT", - f"/_matrix/client/v3/profile/{self.user1}/displayname", - access_token=self.tok1, - content={"displayname": "something offensive"}, - shorthand=False, - ) - self.assertEqual( - channel2.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" - ) - - def test_suspended_user_cannot_redact_messages_other_than_their_own(self) -> None: - # first user sends message - self.make_request("POST", f"/rooms/{self.room1}/join", access_token=self.tok2) - res = self.helper.send_event( - self.room1, - "m.room.message", - {"body": "hello", "msgtype": "m.text"}, - tok=self.tok2, - ) - event_id = res["event_id"] - - # second user sends message - self.make_request("POST", f"/rooms/{self.room1}/join", access_token=self.tok1) - res2 = self.helper.send_event( - self.room1, - "m.room.message", - {"body": "bad_message", "msgtype": "m.text"}, - tok=self.tok1, - ) - event_id2 = res2["event_id"] - - # set the second user as suspended - self.get_success(self.store.set_user_suspended_status(self.user1, True)) - - # second user can't redact first user's message - channel = self.make_request( - "PUT", - f"/_matrix/client/v3/rooms/{self.room1}/redact/{event_id}/1", - access_token=self.tok1, - content={"reason": "bogus"}, - shorthand=False, - ) - self.assertEqual( - channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" - ) - - # but can redact their own - channel = self.make_request( - "PUT", - f"/_matrix/client/v3/rooms/{self.room1}/redact/{event_id2}/1", - access_token=self.tok1, - content={"reason": "bogus"}, - shorthand=False, - ) - self.assertEqual(channel.code, 200) diff --git a/tests/rest/client/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py
index 5ef501c6d5..0a7676e566 100644 --- a/tests/rest/client/test_sendtodevice.py +++ b/tests/rest/client/test_sendtodevice.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -18,39 +17,15 @@ # [This file includes modifications made by New Vector Limited] # # -from parameterized import parameterized_class from synapse.api.constants import EduTypes from synapse.rest import admin from synapse.rest.client import login, sendtodevice, sync -from synapse.types import JsonDict from tests.unittest import HomeserverTestCase, override_config -@parameterized_class( - ("sync_endpoint", "experimental_features"), - [ - ("/sync", {}), - ( - "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee", - # Enable sliding sync - {"msc3575_enabled": True}, - ), - ], -) class SendToDeviceTestCase(HomeserverTestCase): - """ - Test `/sendToDevice` will deliver messages across to people receiving them over `/sync`. - - Attributes: - sync_endpoint: The endpoint under test to use for syncing. - experimental_features: The experimental features homeserver config to use. 
- """ - - sync_endpoint: str - experimental_features: JsonDict - servlets = [ admin.register_servlets, login.register_servlets, @@ -58,11 +33,6 @@ class SendToDeviceTestCase(HomeserverTestCase): sync.register_servlets, ] - def default_config(self) -> JsonDict: - config = super().default_config() - config["experimental_features"] = self.experimental_features - return config - def test_user_to_user(self) -> None: """A to-device message from one user to another should get delivered""" @@ -83,7 +53,7 @@ class SendToDeviceTestCase(HomeserverTestCase): self.assertEqual(chan.code, 200, chan.result) # check it appears - channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok) + channel = self.make_request("GET", "/sync", access_token=user2_tok) self.assertEqual(channel.code, 200, channel.result) expected_result = { "events": [ @@ -96,19 +66,15 @@ class SendToDeviceTestCase(HomeserverTestCase): } self.assertEqual(channel.json_body["to_device"], expected_result) - # it should re-appear if we do another sync because the to-device message is not - # deleted until we acknowledge it by sending a `?since=...` parameter in the - # next sync request corresponding to the `next_batch` value from the response. 
- channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok) + # it should re-appear if we do another sync + channel = self.make_request("GET", "/sync", access_token=user2_tok) self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.json_body["to_device"], expected_result) # it should *not* appear if we do an incremental sync sync_token = channel.json_body["next_batch"] channel = self.make_request( - "GET", - f"{self.sync_endpoint}?since={sync_token}", - access_token=user2_tok, + "GET", f"/sync?since={sync_token}", access_token=user2_tok ) self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), []) @@ -132,19 +98,15 @@ class SendToDeviceTestCase(HomeserverTestCase): ) self.assertEqual(chan.code, 200, chan.result) - # now sync: we should get two of the three (because burst_count=2) - channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok) + # now sync: we should get two of the three + channel = self.make_request("GET", "/sync", access_token=user2_tok) self.assertEqual(channel.code, 200, channel.result) msgs = channel.json_body["to_device"]["events"] self.assertEqual(len(msgs), 2) for i in range(2): self.assertEqual( msgs[i], - { - "sender": user1, - "type": "m.room_key_request", - "content": {"idx": i}, - }, + {"sender": user1, "type": "m.room_key_request", "content": {"idx": i}}, ) sync_token = channel.json_body["next_batch"] @@ -162,9 +124,7 @@ class SendToDeviceTestCase(HomeserverTestCase): # ... 
which should arrive channel = self.make_request( - "GET", - f"{self.sync_endpoint}?since={sync_token}", - access_token=user2_tok, + "GET", f"/sync?since={sync_token}", access_token=user2_tok ) self.assertEqual(channel.code, 200, channel.result) msgs = channel.json_body["to_device"]["events"] @@ -198,7 +158,7 @@ class SendToDeviceTestCase(HomeserverTestCase): ) # now sync: we should get two of the three - channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok) + channel = self.make_request("GET", "/sync", access_token=user2_tok) self.assertEqual(channel.code, 200, channel.result) msgs = channel.json_body["to_device"]["events"] self.assertEqual(len(msgs), 2) @@ -232,9 +192,7 @@ class SendToDeviceTestCase(HomeserverTestCase): # ... which should arrive channel = self.make_request( - "GET", - f"{self.sync_endpoint}?since={sync_token}", - access_token=user2_tok, + "GET", f"/sync?since={sync_token}", access_token=user2_tok ) self.assertEqual(channel.code, 200, channel.result) msgs = channel.json_body["to_device"]["events"] @@ -258,7 +216,7 @@ class SendToDeviceTestCase(HomeserverTestCase): user2_tok = self.login("u2", "pass", "d2") # Do an initial sync - channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok) + channel = self.make_request("GET", "/sync", access_token=user2_tok) self.assertEqual(channel.code, 200, channel.result) sync_token = channel.json_body["next_batch"] @@ -274,9 +232,7 @@ class SendToDeviceTestCase(HomeserverTestCase): self.assertEqual(chan.code, 200, chan.result) channel = self.make_request( - "GET", - f"{self.sync_endpoint}?since={sync_token}&timeout=300000", - access_token=user2_tok, + "GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok ) self.assertEqual(channel.code, 200, channel.result) messages = channel.json_body.get("to_device", {}).get("events", []) @@ -284,9 +240,7 @@ class SendToDeviceTestCase(HomeserverTestCase): sync_token = channel.json_body["next_batch"] channel = 
self.make_request( - "GET", - f"{self.sync_endpoint}?since={sync_token}&timeout=300000", - access_token=user2_tok, + "GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok ) self.assertEqual(channel.code, 200, channel.result) messages = channel.json_body.get("to_device", {}).get("events", []) diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index 2287f233b4..38eb2daf72 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 63df31ec75..d52f1cc34e 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -19,10 +18,9 @@ # # import json -import logging from typing import List -from parameterized import parameterized, parameterized_class +from parameterized import parameterized from twisted.test.proto_helpers import MemoryReactor @@ -44,8 +42,6 @@ from tests.federation.transport.test_knocking import ( ) from tests.server import TimedOutException -logger = logging.getLogger(__name__) - class FilterTestCase(unittest.HomeserverTestCase): user_id = "@apple:test" @@ -691,180 +687,24 @@ class SyncCacheTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.json_body) -@parameterized_class( - ("sync_endpoint", "experimental_features"), - [ - ("/sync", {}), - ( - "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee", - # Enable sliding sync - {"msc3575_enabled": True}, - ), - ], -) class DeviceListSyncTestCase(unittest.HomeserverTestCase): - """ - Tests regarding device list (`device_lists`) changes. - - Attributes: - sync_endpoint: The endpoint under test to use for syncing. - experimental_features: The experimental features homeserver config to use. - """ - - sync_endpoint: str - experimental_features: JsonDict - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, - room.register_servlets, sync.register_servlets, devices.register_servlets, ] - def default_config(self) -> JsonDict: - config = super().default_config() - config["experimental_features"] = self.experimental_features - return config - - def test_receiving_local_device_list_changes(self) -> None: - """Tests that a local users that share a room receive each other's device list - changes. 
- """ - # Register two users - test_device_id = "TESTDEVICE" - alice_user_id = self.register_user("alice", "correcthorse") - alice_access_token = self.login( - alice_user_id, "correcthorse", device_id=test_device_id - ) - - bob_user_id = self.register_user("bob", "ponyponypony") - bob_access_token = self.login(bob_user_id, "ponyponypony") - - # Create a room for them to coexist peacefully in - new_room_id = self.helper.create_room_as( - alice_user_id, is_public=True, tok=alice_access_token - ) - self.assertIsNotNone(new_room_id) - - # Have Bob join the room - self.helper.invite( - new_room_id, alice_user_id, bob_user_id, tok=alice_access_token - ) - self.helper.join(new_room_id, bob_user_id, tok=bob_access_token) - - # Now have Bob initiate an initial sync (in order to get a since token) - channel = self.make_request( - "GET", - self.sync_endpoint, - access_token=bob_access_token, - ) - self.assertEqual(channel.code, 200, channel.json_body) - next_batch_token = channel.json_body["next_batch"] - - # ...and then an incremental sync. This should block until the sync stream is woken up, - # which we hope will happen as a result of Alice updating their device list. - bob_sync_channel = self.make_request( - "GET", - f"{self.sync_endpoint}?since={next_batch_token}&timeout=30000", - access_token=bob_access_token, - # Start the request, then continue on. - await_result=False, - ) - - # Have alice update their device list - channel = self.make_request( - "PUT", - f"/devices/{test_device_id}", - { - "display_name": "New Device Name", - }, - access_token=alice_access_token, - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Check that bob's incremental sync contains the updated device list. - # If not, the client would only receive the device list update on the - # *next* sync. 
- bob_sync_channel.await_result() - self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body) - - changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get( - "changed", [] - ) - self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body) - - def test_not_receiving_local_device_list_changes(self) -> None: - """Tests a local users DO NOT receive device updates from each other if they do not - share a room. - """ - # Register two users - test_device_id = "TESTDEVICE" - alice_user_id = self.register_user("alice", "correcthorse") - alice_access_token = self.login( - alice_user_id, "correcthorse", device_id=test_device_id - ) - - bob_user_id = self.register_user("bob", "ponyponypony") - bob_access_token = self.login(bob_user_id, "ponyponypony") - - # These users do not share a room. They are lonely. - - # Have Bob initiate an initial sync (in order to get a since token) - channel = self.make_request( - "GET", - self.sync_endpoint, - access_token=bob_access_token, - ) - self.assertEqual(channel.code, 200, channel.json_body) - next_batch_token = channel.json_body["next_batch"] - - # ...and then an incremental sync. This should block until the sync stream is woken up, - # which we hope will happen as a result of Alice updating their device list. - bob_sync_channel = self.make_request( - "GET", - f"{self.sync_endpoint}?since={next_batch_token}&timeout=1000", - access_token=bob_access_token, - # Start the request, then continue on. - await_result=False, - ) - - # Have alice update their device list - channel = self.make_request( - "PUT", - f"/devices/{test_device_id}", - { - "display_name": "New Device Name", - }, - access_token=alice_access_token, - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Check that bob's incremental sync does not contain the updated device list. 
- bob_sync_channel.await_result() - self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body) - - changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get( - "changed", [] - ) - self.assertNotIn( - alice_user_id, changed_device_lists, bob_sync_channel.json_body - ) - def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None: """Tests that a user with no rooms still receives their own device list updates""" - test_device_id = "TESTDEVICE" + device_id = "TESTDEVICE" # Register a user and login, creating a device - alice_user_id = self.register_user("alice", "correcthorse") - alice_access_token = self.login( - alice_user_id, "correcthorse", device_id=test_device_id - ) + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey", device_id=device_id) # Request an initial sync - channel = self.make_request( - "GET", self.sync_endpoint, access_token=alice_access_token - ) + channel = self.make_request("GET", "/sync", access_token=self.tok) self.assertEqual(channel.code, 200, channel.json_body) next_batch = channel.json_body["next_batch"] @@ -872,19 +712,19 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase): # It won't return until something has happened incremental_sync_channel = self.make_request( "GET", - f"{self.sync_endpoint}?since={next_batch}&timeout=30000", - access_token=alice_access_token, + f"/sync?since={next_batch}&timeout=30000", + access_token=self.tok, await_result=False, ) # Change our device's display name channel = self.make_request( "PUT", - f"devices/{test_device_id}", + f"devices/{device_id}", { "display_name": "freeze ray", }, - access_token=alice_access_token, + access_token=self.tok, ) self.assertEqual(channel.code, 200, channel.json_body) @@ -898,229 +738,7 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase): ).get("changed", []) self.assertIn( - alice_user_id, device_list_changes, incremental_sync_channel.json_body - ) - - 
-@parameterized_class( - ("sync_endpoint", "experimental_features"), - [ - ("/sync", {}), - ( - "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee", - # Enable sliding sync - {"msc3575_enabled": True}, - ), - ], -) -class DeviceOneTimeKeysSyncTestCase(unittest.HomeserverTestCase): - """ - Tests regarding device one time keys (`device_one_time_keys_count`) changes. - - Attributes: - sync_endpoint: The endpoint under test to use for syncing. - experimental_features: The experimental features homeserver config to use. - """ - - sync_endpoint: str - experimental_features: JsonDict - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - sync.register_servlets, - devices.register_servlets, - ] - - def default_config(self) -> JsonDict: - config = super().default_config() - config["experimental_features"] = self.experimental_features - return config - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.e2e_keys_handler = hs.get_e2e_keys_handler() - - def test_no_device_one_time_keys(self) -> None: - """ - Tests when no one time keys set, it still has the default `signed_curve25519` in - `device_one_time_keys_count` - """ - test_device_id = "TESTDEVICE" - - alice_user_id = self.register_user("alice", "correcthorse") - alice_access_token = self.login( - alice_user_id, "correcthorse", device_id=test_device_id - ) - - # Request an initial sync - channel = self.make_request( - "GET", self.sync_endpoint, access_token=alice_access_token - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Check for those one time key counts - self.assertDictEqual( - channel.json_body["device_one_time_keys_count"], - # Note that "signed_curve25519" is always returned in key count responses - # regardless of whether we uploaded any keys for it. This is necessary until - # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed. 
- {"signed_curve25519": 0}, - channel.json_body["device_one_time_keys_count"], - ) - - def test_returns_device_one_time_keys(self) -> None: - """ - Tests that one time keys for the device/user are counted correctly in the `/sync` - response - """ - test_device_id = "TESTDEVICE" - - alice_user_id = self.register_user("alice", "correcthorse") - alice_access_token = self.login( - alice_user_id, "correcthorse", device_id=test_device_id - ) - - # Upload one time keys for the user/device - keys: JsonDict = { - "alg1:k1": "key1", - "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}}, - "alg2:k3": {"key": "key3"}, - } - res = self.get_success( - self.e2e_keys_handler.upload_keys_for_user( - alice_user_id, test_device_id, {"one_time_keys": keys} - ) - ) - # Note that "signed_curve25519" is always returned in key count responses - # regardless of whether we uploaded any keys for it. This is necessary until - # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed. - self.assertDictEqual( - res, - {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}, - ) - - # Request an initial sync - channel = self.make_request( - "GET", self.sync_endpoint, access_token=alice_access_token - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Check for those one time key counts - self.assertDictEqual( - channel.json_body["device_one_time_keys_count"], - {"alg1": 1, "alg2": 2, "signed_curve25519": 0}, - channel.json_body["device_one_time_keys_count"], - ) - - -@parameterized_class( - ("sync_endpoint", "experimental_features"), - [ - ("/sync", {}), - ( - "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee", - # Enable sliding sync - {"msc3575_enabled": True}, - ), - ], -) -class DeviceUnusedFallbackKeySyncTestCase(unittest.HomeserverTestCase): - """ - Tests regarding device one time keys (`device_unused_fallback_key_types`) changes. - - Attributes: - sync_endpoint: The endpoint under test to use for syncing. 
- experimental_features: The experimental features homeserver config to use. - """ - - sync_endpoint: str - experimental_features: JsonDict - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - sync.register_servlets, - devices.register_servlets, - ] - - def default_config(self) -> JsonDict: - config = super().default_config() - config["experimental_features"] = self.experimental_features - return config - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = self.hs.get_datastores().main - self.e2e_keys_handler = hs.get_e2e_keys_handler() - - def test_no_device_unused_fallback_key(self) -> None: - """ - Test when no unused fallback key is set, it just returns an empty list. The MSC - says "The device_unused_fallback_key_types parameter must be present if the - server supports fallback keys.", - https://github.com/matrix-org/matrix-spec-proposals/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md - """ - test_device_id = "TESTDEVICE" - - alice_user_id = self.register_user("alice", "correcthorse") - alice_access_token = self.login( - alice_user_id, "correcthorse", device_id=test_device_id - ) - - # Request an initial sync - channel = self.make_request( - "GET", self.sync_endpoint, access_token=alice_access_token - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Check for those one time key counts - self.assertListEqual( - channel.json_body["device_unused_fallback_key_types"], - [], - channel.json_body["device_unused_fallback_key_types"], - ) - - def test_returns_device_one_time_keys(self) -> None: - """ - Tests that device unused fallback key type is returned correctly in the `/sync` - """ - test_device_id = "TESTDEVICE" - - alice_user_id = self.register_user("alice", "correcthorse") - alice_access_token = self.login( - alice_user_id, "correcthorse", device_id=test_device_id - ) - - # We shouldn't have any unused fallback keys yet - res = 
self.get_success( - self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id) - ) - self.assertEqual(res, []) - - # Upload a fallback key for the user/device - self.get_success( - self.e2e_keys_handler.upload_keys_for_user( - alice_user_id, - test_device_id, - {"fallback_keys": {"alg1:k1": "fallback_key1"}}, - ) - ) - # We should now have an unused alg1 key - fallback_res = self.get_success( - self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id) - ) - self.assertEqual(fallback_res, ["alg1"], fallback_res) - - # Request an initial sync - channel = self.make_request( - "GET", self.sync_endpoint, access_token=alice_access_token - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Check for the unused fallback key types - self.assertListEqual( - channel.json_body["device_unused_fallback_key_types"], - ["alg1"], - channel.json_body["device_unused_fallback_key_types"], + self.user_id, device_list_changes, incremental_sync_channel.json_body ) diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index d10df1a90f..22b2e34570 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index af1eecbb34..d7f479786d 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py
index 805c49b540..1a83e70f48 100644 --- a/tests/rest/client/test_typing.py +++ b/tests/rest/client/test_typing.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2014-2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
index c4b15c5ae7..7038e42058 100644 --- a/tests/rest/client/test_upgrade_room.py +++ b/tests/rest/client/test_upgrade_room.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index e43140720d..7bcbed246b 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py
@@ -1,9 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019-2021 The Matrix.org Foundation C.I.C. -# Copyright 2017 Vector Creations Ltd -# Copyright 2014-2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -87,7 +84,8 @@ class RestHelper: expect_code: Literal[200] = ..., extra_content: Optional[Dict] = ..., custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ..., - ) -> str: ... + ) -> str: + ... @overload def create_room_as( @@ -99,7 +97,8 @@ class RestHelper: expect_code: int = ..., extra_content: Optional[Dict] = ..., custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ..., - ) -> Optional[str]: ... + ) -> Optional[str]: + ... def create_room_as( self, @@ -170,16 +169,14 @@ class RestHelper: targ: Optional[str] = None, expect_code: int = HTTPStatus.OK, tok: Optional[str] = None, - extra_data: Optional[dict] = None, - ) -> JsonDict: - return self.change_membership( + ) -> None: + self.change_membership( room=room, src=src, targ=targ, tok=tok, membership=Membership.INVITE, expect_code=expect_code, - extra_data=extra_data, ) def join( @@ -191,8 +188,8 @@ class RestHelper: appservice_user_id: Optional[str] = None, expect_errcode: Optional[Codes] = None, expect_additional_fields: Optional[dict] = None, - ) -> JsonDict: - return self.change_membership( + ) -> None: + self.change_membership( room=room, src=user, targ=user, @@ -244,8 +241,8 @@ class RestHelper: user: Optional[str] = None, expect_code: int = HTTPStatus.OK, tok: Optional[str] = None, - ) -> JsonDict: - return self.change_membership( + ) -> None: + self.change_membership( room=room, src=user, targ=user, @@ -261,9 +258,9 @@ class RestHelper: targ: str, expect_code: int = HTTPStatus.OK, tok: Optional[str] = None, - ) -> JsonDict: + ) -> None: """A convenience helper: `change_membership` with `membership` preset to "ban".""" - return 
self.change_membership( + self.change_membership( room=room, src=src, targ=targ, @@ -284,7 +281,7 @@ class RestHelper: expect_code: int = HTTPStatus.OK, expect_errcode: Optional[str] = None, expect_additional_fields: Optional[dict] = None, - ) -> JsonDict: + ) -> None: """ Send a membership state event into a room. @@ -300,9 +297,6 @@ class RestHelper: using an application service access token in `tok`. expect_code: The expected HTTP response code expect_errcode: The expected Matrix error code - - Returns: - The JSON response """ temp_id = self.auth_user_id self.auth_user_id = src @@ -330,12 +324,9 @@ class RestHelper: data, ) - assert ( - channel.code == expect_code - ), "Expected: %d, got: %d, PUT %s -> resp: %r" % ( + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( expect_code, channel.code, - path, channel.result["body"], ) @@ -364,7 +355,6 @@ class RestHelper: ) self.auth_user_id = temp_id - return channel.json_body def send( self, @@ -374,7 +364,6 @@ class RestHelper: tok: Optional[str] = None, expect_code: int = HTTPStatus.OK, custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, - type: str = "m.room.message", ) -> JsonDict: if body is None: body = "body_text_here" @@ -383,7 +372,7 @@ class RestHelper: return self.send_event( room_id, - type, + "m.room.message", content, txn_id, tok, diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 21e12b2a2f..c404b85ec6 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/rest/media/test_domain_blocking.py b/tests/rest/media/test_domain_blocking.py
index 72205c6bb3..b0fbdbfcf4 100644 --- a/tests/rest/media/test_domain_blocking.py +++ b/tests/rest/media/test_domain_blocking.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2023 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -44,13 +43,13 @@ class MediaDomainBlockingTests(unittest.HomeserverTestCase): # from a regular 404. file_id = "abcdefg12345" file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id) - - media_storage = hs.get_media_repository().media_storage - - ctx = media_storage.store_into_file(file_info) - (f, fname) = self.get_success(ctx.__aenter__()) - f.write(SMALL_PNG) - self.get_success(ctx.__aexit__(None, None, None)) + with hs.get_media_repository().media_storage.store_into_file(file_info) as ( + f, + fname, + finish, + ): + f.write(SMALL_PNG) + self.get_success(finish()) self.get_success( self.store.store_cached_remote_media( diff --git a/tests/rest/media/test_url_preview.py b/tests/rest/media/test_url_preview.py
index a96f0e7fca..0b952e185d 100644 --- a/tests/rest/media/test_url_preview.py +++ b/tests/rest/media/test_url_preview.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/rest/synapse/__init__.py b/tests/rest/synapse/__init__.py deleted file mode 100644
index e5138f67e1..0000000000 --- a/tests/rest/synapse/__init__.py +++ /dev/null
@@ -1,12 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2024 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. diff --git a/tests/rest/synapse/client/__init__.py b/tests/rest/synapse/client/__init__.py deleted file mode 100644
index e5138f67e1..0000000000 --- a/tests/rest/synapse/client/__init__.py +++ /dev/null
@@ -1,12 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2024 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. diff --git a/tests/rest/synapse/client/test_federation_whitelist.py b/tests/rest/synapse/client/test_federation_whitelist.py deleted file mode 100644
index f0067a8f2b..0000000000 --- a/tests/rest/synapse/client/test_federation_whitelist.py +++ /dev/null
@@ -1,119 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2024 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. - -from typing import Dict - -from twisted.web.resource import Resource - -from synapse.rest import admin -from synapse.rest.client import login -from synapse.rest.synapse.client import build_synapse_client_resource_tree - -from tests import unittest - - -class FederationWhitelistTests(unittest.HomeserverTestCase): - servlets = [ - admin.register_servlets_for_client_rest_resource, - login.register_servlets, - ] - - def create_resource_dict(self) -> Dict[str, Resource]: - base = super().create_resource_dict() - base.update(build_synapse_client_resource_tree(self.hs)) - return base - - def test_default(self) -> None: - "If the config option is not enabled, the endpoint should 404" - channel = self.make_request( - "GET", "/_synapse/client/v1/config/federation_whitelist", shorthand=False - ) - - self.assertEqual(channel.code, 404) - - @unittest.override_config({"federation_whitelist_endpoint_enabled": True}) - def test_no_auth(self) -> None: - "Endpoint requires auth when enabled" - - channel = self.make_request( - "GET", "/_synapse/client/v1/config/federation_whitelist", shorthand=False - ) - - self.assertEqual(channel.code, 401) - - @unittest.override_config({"federation_whitelist_endpoint_enabled": True}) - def test_no_whitelist(self) -> None: - "Test when there is no whitelist configured" - - self.register_user("user", "password") - tok = self.login("user", "password") - - channel = self.make_request( - "GET", - 
"/_synapse/client/v1/config/federation_whitelist", - shorthand=False, - access_token=tok, - ) - - self.assertEqual(channel.code, 200) - self.assertEqual( - channel.json_body, {"whitelist_enabled": False, "whitelist": []} - ) - - @unittest.override_config( - { - "federation_whitelist_endpoint_enabled": True, - "federation_domain_whitelist": ["example.com"], - } - ) - def test_whitelist(self) -> None: - "Test when there is a whitelist configured" - - self.register_user("user", "password") - tok = self.login("user", "password") - - channel = self.make_request( - "GET", - "/_synapse/client/v1/config/federation_whitelist", - shorthand=False, - access_token=tok, - ) - - self.assertEqual(channel.code, 200) - self.assertEqual( - channel.json_body, {"whitelist_enabled": True, "whitelist": ["example.com"]} - ) - - @unittest.override_config( - { - "federation_whitelist_endpoint_enabled": True, - "federation_domain_whitelist": ["example.com", "example.com"], - } - ) - def test_whitelist_no_duplicates(self) -> None: - "Test when there is a whitelist configured with duplicates, no duplicates are returned" - - self.register_user("user", "password") - tok = self.login("user", "password") - - channel = self.make_request( - "GET", - "/_synapse/client/v1/config/federation_whitelist", - shorthand=False, - access_token=tok, - ) - - self.assertEqual(channel.code, 200) - self.assertEqual( - channel.json_body, {"whitelist_enabled": True, "whitelist": ["example.com"]} - ) diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py
index bdbfce796a..5c6fd57df6 100644 --- a/tests/rest/test_health.py +++ b/tests/rest/test_health.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/server.py b/tests/server.py
index 3e377585ce..f94d0e4397 100644 --- a/tests/server.py +++ b/tests/server.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -47,7 +46,7 @@ from typing import ( Union, cast, ) -from unittest.mock import Mock, patch +from unittest.mock import Mock import attr from incremental import Version @@ -55,7 +54,6 @@ from typing_extensions import ParamSpec from zope.interface import implementer import twisted -from twisted.enterprise import adbapi from twisted.internet import address, tcp, threads, udp from twisted.internet._resolver import SimpleResolverComplexifier from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed @@ -85,7 +83,6 @@ from twisted.web.server import Request, Site from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig -from synapse.events.auto_accept_invites import InviteAutoAccepter from synapse.events.presence_router import load_legacy_presence_router from synapse.handlers.auth import load_legacy_password_auth_providers from synapse.http.site import SynapseRequest @@ -96,8 +93,8 @@ from synapse.module_api.callbacks.third_party_event_rules_callbacks import ( ) from synapse.server import HomeServer from synapse.storage import DataStore -from synapse.storage.database import LoggingDatabaseConnection, make_pool -from synapse.storage.engines import BaseDatabaseEngine, create_engine +from synapse.storage.database import LoggingDatabaseConnection +from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database from synapse.types import ISynapseReactor, JsonDict from synapse.util import Clock @@ -198,35 +195,17 @@ class FakeChannel: def headers(self) -> Headers: if not self.result: raise Exception("No result yet.") - - h = self.result["headers"] - assert isinstance(h, Headers) + h = Headers() 
+ for i in self.result["headers"]: + h.addRawHeader(*i) return h def writeHeaders( - self, - version: bytes, - code: bytes, - reason: bytes, - headers: Union[Headers, List[Tuple[bytes, bytes]]], + self, version: bytes, code: bytes, reason: bytes, headers: Headers ) -> None: self.result["version"] = version self.result["code"] = code self.result["reason"] = reason - - if isinstance(headers, list): - # Support prior to Twisted 24.7.0rc1 - new_headers = Headers() - for k, v in headers: - assert isinstance(k, bytes), f"key is not of type bytes: {k!r}" - assert isinstance(v, bytes), f"value is not of type bytes: {v!r}" - new_headers.addRawHeader(k, v) - headers = new_headers - - assert isinstance( - headers, Headers - ), f"headers are of the wrong type: {headers!r}" - self.result["headers"] = headers def write(self, data: bytes) -> None: @@ -307,6 +286,10 @@ class FakeChannel: self._reactor.run() while not self.is_finished(): + # If there's a producer, tell it to resume producing so we get content + if self._producer: + self._producer.resumeProducing() + if self._reactor.seconds() > end_time: raise TimedOutException("Timed out waiting for request to finish.") @@ -366,7 +349,6 @@ def make_request( request: Type[Request] = SynapseRequest, shorthand: bool = True, federation_auth_origin: Optional[bytes] = None, - content_type: Optional[bytes] = None, content_is_form: bool = False, await_result: bool = True, custom_headers: Optional[Iterable[CustomHeaderType]] = None, @@ -389,8 +371,6 @@ def make_request( with the usual REST API path, if it doesn't contain it. federation_auth_origin: if set to not-None, we will add a fake Authorization header pretenting to be the given server name. - content_type: The content-type to use for the request. If not set then will default to - application/json unless content_is_form is true. content_is_form: Whether the content is URL encoded form data. Adds the 'Content-Type': 'application/x-www-form-urlencoded' header. 
await_result: whether to wait for the request to complete rendering. If true, @@ -454,9 +434,7 @@ def make_request( ) if content: - if content_type is not None: - req.requestHeaders.addRawHeader(b"Content-Type", content_type) - elif content_is_form: + if content_is_form: req.requestHeaders.addRawHeader( b"Content-Type", b"application/x-www-form-urlencoded" ) @@ -691,53 +669,6 @@ def validate_connector(connector: tcp.Connector, expected_ip: str) -> None: ) -def make_fake_db_pool( - reactor: ISynapseReactor, - db_config: DatabaseConnectionConfig, - engine: BaseDatabaseEngine, -) -> adbapi.ConnectionPool: - """Wrapper for `make_pool` which builds a pool which runs db queries synchronously. - - For more deterministic testing, we don't use a regular db connection pool: instead - we run all db queries synchronously on the test reactor's main thread. This function - is a drop-in replacement for the normal `make_pool` which builds such a connection - pool. - """ - pool = make_pool(reactor, db_config, engine) - - def runWithConnection( - func: Callable[..., R], *args: Any, **kwargs: Any - ) -> Awaitable[R]: - return threads.deferToThreadPool( - pool._reactor, - pool.threadpool, - pool._runWithConnection, - func, - *args, - **kwargs, - ) - - def runInteraction( - desc: str, func: Callable[..., R], *args: Any, **kwargs: Any - ) -> Awaitable[R]: - return threads.deferToThreadPool( - pool._reactor, - pool.threadpool, - pool._runInteraction, - desc, - func, - *args, - **kwargs, - ) - - pool.runWithConnection = runWithConnection # type: ignore[method-assign] - pool.runInteraction = runInteraction # type: ignore[assignment] - # Replace the thread pool with a threadless 'thread' pool - pool.threadpool = ThreadPool(reactor) - pool.running = True - return pool - - class ThreadPool: """ Threadless thread pool. 
@@ -774,6 +705,52 @@ class ThreadPool: return d +def _make_test_homeserver_synchronous(server: HomeServer) -> None: + """ + Make the given test homeserver's database interactions synchronous. + """ + + clock = server.get_clock() + + for database in server.get_datastores().databases: + pool = database._db_pool + + def runWithConnection( + func: Callable[..., R], *args: Any, **kwargs: Any + ) -> Awaitable[R]: + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runWithConnection, + func, + *args, + **kwargs, + ) + + def runInteraction( + desc: str, func: Callable[..., R], *args: Any, **kwargs: Any + ) -> Awaitable[R]: + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runInteraction, + desc, + func, + *args, + **kwargs, + ) + + pool.runWithConnection = runWithConnection # type: ignore[method-assign] + pool.runInteraction = runInteraction # type: ignore[assignment] + # Replace the thread pool with a threadless 'thread' pool + pool.threadpool = ThreadPool(clock._reactor) + pool.running = True + + # We've just changed the Databases to run DB transactions on the same + # thread, so we need to disable the dedicated thread behaviour. + server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False + + def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]: clock = ThreadedMemoryReactorClock() hs_clock = Clock(clock) @@ -960,7 +937,7 @@ def connect_client( class TestHomeServer(HomeServer): - DATASTORE_CLASS = DataStore + DATASTORE_CLASS = DataStore # type: ignore[assignment] def setup_test_homeserver( @@ -1089,14 +1066,7 @@ def setup_test_homeserver( # Mock TLS hs.tls_server_context_factory = Mock() - # Patch `make_pool` before initialising the database, to make database transactions - # synchronous for testing. 
- with patch("synapse.storage.database.make_pool", side_effect=make_fake_db_pool): - hs.setup() - - # Since we've changed the databases to run DB transactions on the same - # thread, we need to stop the event fetcher hogging that one thread. - hs.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False + hs.setup() if USE_POSTGRES_FOR_TESTS: database_pool = hs.get_datastores().databases[0] @@ -1166,16 +1136,14 @@ def setup_test_homeserver( hs.get_auth_handler().validate_hash = validate_hash # type: ignore[assignment] + # Make the threadpool and database transactions synchronous for testing. + _make_test_homeserver_synchronous(hs) + # Load any configured modules into the homeserver module_api = hs.get_module_api() for module, module_config in hs.config.modules.loaded_modules: module(config=module_config, api=module_api) - if hs.config.auto_accept_invites.enabled: - # Start the local auto_accept_invites module. - m = InviteAutoAccepter(hs.config.auto_accept_invites, module_api) - logger.info("Loaded local module %s", m) - load_legacy_spam_checkers(hs) load_legacy_third_party_event_rules(hs) load_legacy_presence_router(hs) diff --git a/tests/storage/databases/__init__.py b/tests/storage/databases/__init__.py
index 7e89a998b5..3d833a2e44 100644 --- a/tests/storage/databases/__init__.py +++ b/tests/storage/databases/__init__.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/databases/main/__init__.py b/tests/storage/databases/main/__init__.py
index 7e89a998b5..3d833a2e44 100644 --- a/tests/storage/databases/main/__init__.py +++ b/tests/storage/databases/main/__init__.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/databases/main/test_cache.py b/tests/storage/databases/main/test_cache.py
index 3900a35290..e493588b5d 100644 --- a/tests/storage/databases/main/test_cache.py +++ b/tests/storage/databases/main/test_cache.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2023 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/databases/main/test_deviceinbox.py b/tests/storage/databases/main/test_deviceinbox.py
index f89f96fdcb..bb1daeacb7 100644 --- a/tests/storage/databases/main/test_deviceinbox.py +++ b/tests/storage/databases/main/test_deviceinbox.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/databases/main/test_end_to_end_keys.py b/tests/storage/databases/main/test_end_to_end_keys.py
index 1ed1d01cea..fd6b3e9f98 100644 --- a/tests/storage/databases/main/test_end_to_end_keys.py +++ b/tests/storage/databases/main/test_end_to_end_keys.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2023 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index fd1f5e7fd5..caa5752032 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py
index 1df71e723e..66c8864b1a 100644 --- a/tests/storage/databases/main/test_lock.py +++ b/tests/storage/databases/main/test_lock.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/databases/main/test_receipts.py b/tests/storage/databases/main/test_receipts.py
index da2ec26421..e000394db2 100644 --- a/tests/storage/databases/main/test_receipts.py +++ b/tests/storage/databases/main/test_receipts.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py
index 88a5aa8cb1..232aad5b79 100644 --- a/tests/storage/databases/main/test_room.py +++ b/tests/storage/databases/main/test_room.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 506d981ce6..b3dc4fe848 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2015, 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py
index 2859bcf4bd..42fbc4db60 100644 --- a/tests/storage/test_account_data.py +++ b/tests/storage/test_account_data.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 10533f45d7..baa55ca862 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2015, 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index b28db6a4ad..a83dfd1717 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 9420d03841..d0176309cd 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2014-2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index d5b9996284..058c6caf90 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -337,15 +336,15 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): """Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion() expires old entries correctly. """ - self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion["1"] = ( - 100000 - ) - self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion["2"] = ( - 200000 - ) - self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion["3"] = ( - 300000 - ) + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[ + "1" + ] = 100000 + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[ + "2" + ] = 200000 + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[ + "3" + ] = 300000 self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion() # All entries within time frame diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 13f78ee2d2..91c76e82ff 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py
index 8af0d6265b..550ca26412 100644 --- a/tests/storage/test_database.py +++ b/tests/storage/test_database.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index ba01b038ab..cb8cb06c84 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2016-2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -36,14 +35,6 @@ class DeviceStoreTestCase(HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main - def default_config(self) -> JsonDict: - config = super().default_config() - - # We 'enable' federation otherwise `get_device_updates_by_remote` will - # throw an exception. - config["federation_sender_instances"] = ["master"] - return config - def add_device_change(self, user_id: str, device_ids: List[str], host: str) -> None: """Add a device list change for the given device to `device_lists_outbound_pokes` table. diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index f1602fdc86..7be9ee102a 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2014-2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py
index 931d37e85a..40d2f3d88c 100644 --- a/tests/storage/test_e2e_room_keys.py +++ b/tests/storage/test_e2e_room_keys.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index bd594d3c1f..44581a4cc0 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2016-2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index c4e216c308..e3d592e542 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -21,8 +20,6 @@ from typing import Dict, List, Set, Tuple, cast -from parameterized import parameterized - from twisted.test.proto_helpers import MemoryReactor from twisted.trial import unittest @@ -47,8 +44,7 @@ class EventChainStoreTestCase(HomeserverTestCase): self.store = hs.get_datastores().main self._next_stream_ordering = 1 - @parameterized.expand([(False,), (True,)]) - def test_simple(self, batched: bool) -> None: + def test_simple(self) -> None: """Test that the example in `docs/auth_chain_difference_algorithm.md` works. """ @@ -56,7 +52,6 @@ class EventChainStoreTestCase(HomeserverTestCase): event_factory = self.hs.get_event_builder_factory() bob = "@creator:test" alice = "@alice:test" - charlie = "@charlie:test" room_id = "!room:test" # Ensure that we have a rooms entry so that we generate the chain index. 
@@ -195,26 +190,6 @@ class EventChainStoreTestCase(HomeserverTestCase): ) ) - charlie_invite = self.get_success( - event_factory.for_room_version( - RoomVersions.V6, - { - "type": EventTypes.Member, - "state_key": charlie, - "sender": alice, - "room_id": room_id, - "content": {"tag": "charlie_invite"}, - }, - ).build( - prev_event_ids=[], - auth_event_ids=[ - create.event_id, - alice_join2.event_id, - power_2.event_id, - ], - ) - ) - events = [ create, bob_join, @@ -224,41 +199,33 @@ class EventChainStoreTestCase(HomeserverTestCase): bob_join_2, power_2, alice_join2, - charlie_invite, ] expected_links = [ (bob_join, create), + (power, create), (power, bob_join), + (alice_invite, create), (alice_invite, power), + (alice_invite, bob_join), (bob_join_2, power), (alice_join2, power_2), - (charlie_invite, alice_join2), ] - # We either persist as a batch or one-by-one depending on test - # parameter. - if batched: - self.persist(events) - else: - for event in events: - self.persist([event]) - + self.persist(events) chain_map, link_map = self.fetch_chains(events) # Check that the expected links and only the expected links have been # added. - event_map = {e.event_id: e for e in events} - reverse_chain_map = {v: event_map[k] for k, v in chain_map.items()} + self.assertEqual(len(expected_links), len(list(link_map.get_additions()))) - self.maxDiff = None - self.assertCountEqual( - expected_links, - [ - (reverse_chain_map[(s1, s2)], reverse_chain_map[(t1, t2)]) - for s1, s2, t1, t2 in link_map.get_additions() - ], - ) + for start, end in expected_links: + start_id, start_seq = chain_map[start.event_id] + end_id, end_seq = chain_map[end.event_id] + + self.assertIn( + (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id)) + ) # Test that everything can reach the create event, but the create event # can't reach anything. 
@@ -400,23 +367,24 @@ class EventChainStoreTestCase(HomeserverTestCase): expected_links = [ (bob_join, create), + (power, create), (power, bob_join), + (alice_invite, create), (alice_invite, power), + (alice_invite, bob_join), ] # Check that the expected links and only the expected links have been # added. - event_map = {e.event_id: e for e in events} - reverse_chain_map = {v: event_map[k] for k, v in chain_map.items()} + self.assertEqual(len(expected_links), len(list(link_map.get_additions()))) - self.maxDiff = None - self.assertCountEqual( - expected_links, - [ - (reverse_chain_map[(s1, s2)], reverse_chain_map[(t1, t2)]) - for s1, s2, t1, t2 in link_map.get_additions() - ], - ) + for start, end in expected_links: + start_id, start_seq = chain_map[start.event_id] + end_id, end_seq = chain_map[end.event_id] + + self.assertIn( + (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id)) + ) def persist( self, @@ -431,7 +399,6 @@ class EventChainStoreTestCase(HomeserverTestCase): for e in events: e.internal_metadata.stream_ordering = self._next_stream_ordering - e.internal_metadata.instance_name = self.hs.get_instance_name() self._next_stream_ordering += 1 def _persist(txn: LoggingTransaction) -> None: @@ -447,14 +414,7 @@ class EventChainStoreTestCase(HomeserverTestCase): ) # Actually call the function that calculates the auth chain stuff. 
- new_event_links = ( - persist_events_store.calculate_chain_cover_index_for_events_txn( - txn, events[0].room_id, [e for e in events if e.is_state()] - ) - ) - persist_events_store._persist_event_auth_chain_txn( - txn, events, new_event_links - ) + persist_events_store._persist_event_auth_chain_txn(txn, events) self.get_success( persist_events_store.db_pool.runInteraction( @@ -528,6 +488,8 @@ class LinkMapTestCase(unittest.TestCase): link_map = _LinkMap() link_map.add_link((1, 1), (2, 1), new=False) + self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)]) + self.assertCountEqual(link_map.get_links_from((1, 1)), [(2, 1)]) self.assertCountEqual(link_map.get_additions(), []) self.assertTrue(link_map.exists_path_from((1, 5), (2, 1))) self.assertFalse(link_map.exists_path_from((1, 5), (2, 2))) @@ -536,30 +498,17 @@ class LinkMapTestCase(unittest.TestCase): # Attempting to add a redundant link is ignored. self.assertFalse(link_map.add_link((1, 4), (2, 1))) - self.assertCountEqual(link_map.get_additions(), []) + self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)]) # Adding new non-redundant links works self.assertTrue(link_map.add_link((1, 3), (2, 3))) - self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3)]) + self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)]) self.assertTrue(link_map.add_link((2, 5), (1, 3))) - self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)]) + self.assertCountEqual(link_map.get_links_between(2, 1), [(5, 3)]) + self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)]) - def test_exists_path_from(self) -> None: - "Check that `exists_path_from` can handle non-direct links" - link_map = _LinkMap() - - link_map.add_link((1, 1), (2, 1), new=False) - link_map.add_link((2, 1), (3, 1), new=False) - - self.assertTrue(link_map.exists_path_from((1, 4), (3, 1))) - self.assertFalse(link_map.exists_path_from((1, 4), (3, 2))) - - link_map.add_link((1, 5), (2, 3), 
new=False) - link_map.add_link((2, 2), (3, 3), new=False) - - self.assertTrue(link_map.exists_path_from((1, 6), (3, 2))) - self.assertFalse(link_map.exists_path_from((1, 4), (3, 2))) + self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)]) class EventChainBackgroundUpdateTestCase(HomeserverTestCase): diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 088f0d24f9..0a6253e22c 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py
@@ -365,19 +365,12 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): }, ) - events = [ - cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id])) - for event_id in AUTH_GRAPH - ] - new_event_links = ( - self.persist_events.calculate_chain_cover_index_for_events_txn( - txn, room_id, [e for e in events if e.is_state()] - ) - ) self.persist_events._persist_event_auth_chain_txn( txn, - events, - new_event_links, + [ + cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id])) + for event_id in AUTH_GRAPH + ], ) self.get_success( @@ -551,9 +544,6 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): rooms. """ - # We allow partial covers for this test - self.hs.get_datastores().main.tests_allow_no_chain_cover_index = True - room_id = "@ROOM:local" # The silly auth graph we use to test the auth difference algorithm, @@ -638,20 +628,13 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): ) # Insert all events apart from 'B' - events = [ - cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id])) - for event_id in auth_graph - if event_id != "b" - ] - new_event_links = ( - self.persist_events.calculate_chain_cover_index_for_events_txn( - txn, room_id, [e for e in events if e.is_state()] - ) - ) self.persist_events._persist_event_auth_chain_txn( txn, - events, - new_event_links, + [ + cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id])) + for event_id in auth_graph + if event_id != "b" + ], ) # Now we insert the event 'B' without a chain cover, by temporarily @@ -664,14 +647,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): updatevalues={"has_auth_chain_index": False}, ) - events = [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))] - new_event_links = ( - self.persist_events.calculate_chain_cover_index_for_events_txn( - txn, room_id, [e for e in events if e.is_state()] - ) - ) 
self.persist_events._persist_event_auth_chain_txn( - txn, events, new_event_links + txn, + [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))], ) self.store.db_pool.simple_update_txn( diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index 3f7ee86498..02f7e6cd39 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 233066bf82..9639d49d5f 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2016-2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index 0a7c4c9421..eb37700192 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 12b89cecb6..aef52f131e 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -18,7 +17,7 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import Dict, List, Optional +from typing import List, Optional from twisted.test.proto_helpers import MemoryReactor @@ -28,55 +27,177 @@ from synapse.storage.database import ( LoggingDatabaseConnection, LoggingTransaction, ) +from synapse.storage.engines import IncorrectDatabaseSetup from synapse.storage.types import Cursor -from synapse.storage.util.id_generators import MultiWriterIdGenerator -from synapse.storage.util.sequence import ( - LocalSequenceGenerator, - PostgresSequenceGenerator, - SequenceGenerator, -) +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.util import Clock from tests.unittest import HomeserverTestCase from tests.utils import USE_POSTGRES_FOR_TESTS -class MultiWriterIdGeneratorBase(HomeserverTestCase): - positive: bool = True - tables: List[str] = ["foobar"] - +class StreamIdGeneratorTestCase(HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool - self.instances: Dict[str, MultiWriterIdGenerator] = {} self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) - if USE_POSTGRES_FOR_TESTS: - self.seq_gen: SequenceGenerator = PostgresSequenceGenerator("foobar_seq") - else: - self.seq_gen = LocalSequenceGenerator(lambda _: 0) - def _setup_db(self, txn: LoggingTransaction) -> None: - if USE_POSTGRES_FOR_TESTS: - txn.execute("CREATE SEQUENCE foobar_seq") + txn.execute( + """ + CREATE TABLE foobar ( + stream_id BIGINT NOT NULL, + data TEXT + ); + """ + ) + txn.execute("INSERT INTO foobar 
VALUES (123, 'hello world');") - for table in self.tables: - txn.execute( - """ - CREATE TABLE %s ( - stream_id BIGINT NOT NULL, - instance_name TEXT NOT NULL, - data TEXT - ); - """ - % (table,) + def _create_id_generator(self) -> StreamIdGenerator: + def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator: + return StreamIdGenerator( + db_conn=conn, + notifier=self.hs.get_replication_notifier(), + table="foobar", + column="stream_id", ) + return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) + + def test_initial_value(self) -> None: + """Check that we read the current token from the DB.""" + id_gen = self._create_id_generator() + self.assertEqual(id_gen.get_current_token(), 123) + + def test_single_gen_next(self) -> None: + """Check that we correctly increment the current token from the DB.""" + id_gen = self._create_id_generator() + + async def test_gen_next() -> None: + async with id_gen.get_next() as next_id: + # We haven't persisted `next_id` yet; current token is still 123 + self.assertEqual(id_gen.get_current_token(), 123) + # But we did learn what the next value is + self.assertEqual(next_id, 124) + + # Once the context manager closes we assume that the `next_id` has been + # written to the DB. + self.assertEqual(id_gen.get_current_token(), 124) + + self.get_success(test_gen_next()) + + def test_multiple_gen_nexts(self) -> None: + """Check that we handle overlapping calls to gen_next sensibly.""" + id_gen = self._create_id_generator() + + async def test_gen_next() -> None: + ctx1 = id_gen.get_next() + ctx2 = id_gen.get_next() + ctx3 = id_gen.get_next() + + # Request three new stream IDs. + self.assertEqual(await ctx1.__aenter__(), 124) + self.assertEqual(await ctx2.__aenter__(), 125) + self.assertEqual(await ctx3.__aenter__(), 126) + + # None are persisted: current token unchanged. + self.assertEqual(id_gen.get_current_token(), 123) + + # Persist each in turn. 
+ await ctx1.__aexit__(None, None, None) + self.assertEqual(id_gen.get_current_token(), 124) + await ctx2.__aexit__(None, None, None) + self.assertEqual(id_gen.get_current_token(), 125) + await ctx3.__aexit__(None, None, None) + self.assertEqual(id_gen.get_current_token(), 126) + + self.get_success(test_gen_next()) + + def test_multiple_gen_nexts_closed_in_different_order(self) -> None: + """Check that we handle overlapping calls to gen_next, even when their IDs + created and persisted in different orders.""" + id_gen = self._create_id_generator() + + async def test_gen_next() -> None: + ctx1 = id_gen.get_next() + ctx2 = id_gen.get_next() + ctx3 = id_gen.get_next() + + # Request three new stream IDs. + self.assertEqual(await ctx1.__aenter__(), 124) + self.assertEqual(await ctx2.__aenter__(), 125) + self.assertEqual(await ctx3.__aenter__(), 126) + + # None are persisted: current token unchanged. + self.assertEqual(id_gen.get_current_token(), 123) + + # Persist them in a different order, starting with 126 from ctx3. + await ctx3.__aexit__(None, None, None) + # We haven't persisted 124 from ctx1 yet---current token is still 123. + self.assertEqual(id_gen.get_current_token(), 123) + + # Now persist 124 from ctx1. + await ctx1.__aexit__(None, None, None) + # Current token is then 124, waiting for 125 to be persisted. + self.assertEqual(id_gen.get_current_token(), 124) + + # Finally persist 125 from ctx2. + await ctx2.__aexit__(None, None, None) + # Current token is then 126 (skipping over 125). + self.assertEqual(id_gen.get_current_token(), 126) + + self.get_success(test_gen_next()) + + def test_gen_next_while_still_waiting_for_persistence(self) -> None: + """Check that we handle overlapping calls to gen_next.""" + id_gen = self._create_id_generator() + + async def test_gen_next() -> None: + ctx1 = id_gen.get_next() + ctx2 = id_gen.get_next() + ctx3 = id_gen.get_next() + + # Request two new stream IDs. 
+ self.assertEqual(await ctx1.__aenter__(), 124) + self.assertEqual(await ctx2.__aenter__(), 125) + + # Persist ctx2 first. + await ctx2.__aexit__(None, None, None) + # Still waiting on ctx1's ID to be persisted. + self.assertEqual(id_gen.get_current_token(), 123) + + # Now request a third stream ID. It should be 126 (the smallest ID that + # we've not yet handed out.) + self.assertEqual(await ctx3.__aenter__(), 126) + + self.get_success(test_gen_next()) + + +class MultiWriterIdGeneratorTestCase(HomeserverTestCase): + if not USE_POSTGRES_FOR_TESTS: + skip = "Requires Postgres" + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.db_pool: DatabasePool = self.store.db_pool + + self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) + + def _setup_db(self, txn: LoggingTransaction) -> None: + txn.execute("CREATE SEQUENCE foobar_seq") + txn.execute( + """ + CREATE TABLE foobar ( + stream_id BIGINT NOT NULL, + instance_name TEXT NOT NULL, + data TEXT + ); + """ + ) + def _create_id_generator( - self, - instance_name: str = "master", - writers: Optional[List[str]] = None, + self, instance_name: str = "master", writers: Optional[List[str]] = None ) -> MultiWriterIdGenerator: def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator: return MultiWriterIdGenerator( @@ -85,98 +206,58 @@ class MultiWriterIdGeneratorBase(HomeserverTestCase): notifier=self.hs.get_replication_notifier(), stream_name="test_stream", instance_name=instance_name, - tables=[(table, "instance_name", "stream_id") for table in self.tables], + tables=[("foobar", "instance_name", "stream_id")], sequence_name="foobar_seq", writers=writers or ["master"], - positive=self.positive, ) - self.instances[instance_name] = self.get_success_or_raise( - self.db_pool.runWithConnection(_create) - ) - return self.instances[instance_name] - - def _replicate(self, instance_name: str) -> None: - """Similate a 
replication event for the given instance.""" - - writer = self.instances[instance_name] - token = writer.get_current_token_for_writer(instance_name) - for generator in self.instances.values(): - if writer != generator: - generator.advance(instance_name, token) + return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) - def _replicate_all(self) -> None: - """Similate a replication event for all instances.""" - - for instance_name in self.instances: - self._replicate(instance_name) + def _insert_rows(self, instance_name: str, number: int) -> None: + """Insert N rows as the given instance, inserting with stream IDs pulled + from the postgres sequence. + """ - def _insert_row( - self, instance_name: str, stream_id: int, table: Optional[str] = None - ) -> None: - """Insert one row as the given instance with given stream_id.""" + def _insert(txn: LoggingTransaction) -> None: + for _ in range(number): + txn.execute( + "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)", + (instance_name,), + ) + txn.execute( + """ + INSERT INTO stream_positions VALUES ('test_stream', ?, lastval()) + ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval() + """, + (instance_name,), + ) - if table is None: - table = self.tables[0] + self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) - factor = 1 if self.positive else -1 + def _insert_row_with_id(self, instance_name: str, stream_id: int) -> None: + """Insert one row as the given instance with given stream_id, updating + the postgres sequence position to match. + """ def _insert(txn: LoggingTransaction) -> None: txn.execute( - "INSERT INTO %s VALUES (?, ?)" % (table,), + "INSERT INTO foobar VALUES (?, ?)", ( stream_id, instance_name, ), ) + txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,)) txn.execute( """ INSERT INTO stream_positions VALUES ('test_stream', ?, ?) ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ? 
""", - (instance_name, stream_id * factor, stream_id * factor), + (instance_name, stream_id, stream_id), ) - self.get_success(self.db_pool.runInteraction("_insert_row", _insert)) - - def _insert_rows( - self, - instance_name: str, - number: int, - table: Optional[str] = None, - update_stream_table: bool = True, - ) -> None: - """Insert N rows as the given instance, inserting with stream IDs pulled - from the postgres sequence. - """ - - if table is None: - table = self.tables[0] - - factor = 1 if self.positive else -1 - - def _insert(txn: LoggingTransaction) -> None: - for _ in range(number): - next_val = self.seq_gen.get_next_id_txn(txn) - txn.execute( - "INSERT INTO %s (stream_id, instance_name) VALUES (?, ?)" - % (table,), - (next_val, instance_name), - ) - - if update_stream_table: - txn.execute( - """ - INSERT INTO stream_positions VALUES ('test_stream', ?, ?) - ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ? - """, - (instance_name, next_val * factor, next_val * factor), - ) - - self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) - + self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert)) -class MultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase): def test_empty(self) -> None: """Test an ID generator against an empty database gives sensible current positions. @@ -265,106 +346,137 @@ class MultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase): self.assertEqual(id_gen.get_positions(), {"master": 11}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 11) - def test_get_next_txn(self) -> None: - """Test that the `get_next_txn` function works correctly.""" + def test_multi_instance(self) -> None: + """Test that reads and writes from multiple processes are handled + correctly. 
+ """ + self._insert_rows("first", 3) + self._insert_rows("second", 4) - # Prefill table with 7 rows written by 'master' - self._insert_rows("master", 7) + first_id_gen = self._create_id_generator("first", writers=["first", "second"]) + second_id_gen = self._create_id_generator("second", writers=["first", "second"]) - id_gen = self._create_id_generator() + self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) + self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7) + self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7) - self.assertEqual(id_gen.get_positions(), {"master": 7}) - self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) + self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7) + self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7) # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. 
- def _get_next_txn(txn: LoggingTransaction) -> None: - stream_id = id_gen.get_next_txn(txn) - self.assertEqual(stream_id, 8) + async def _get_next_async() -> None: + async with first_id_gen.get_next() as stream_id: + self.assertEqual(stream_id, 8) - self.assertEqual(id_gen.get_positions(), {"master": 7}) - self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + self.assertEqual( + first_id_gen.get_positions(), {"first": 3, "second": 7} + ) + self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) - self.get_success(self.db_pool.runInteraction("test", _get_next_txn)) + self.get_success(_get_next_async()) - self.assertEqual(id_gen.get_positions(), {"master": 8}) - self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) + self.assertEqual(first_id_gen.get_positions(), {"first": 8, "second": 7}) - def test_restart_during_out_of_order_persistence(self) -> None: - """Test that restarting a process while another process is writing out - of order updates are handled correctly. - """ + # However the ID gen on the second instance won't have seen the update + self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) - # Prefill table with 7 rows written by 'master' - self._insert_rows("master", 7) + # ... 
but calling `get_next` on the second instance should give a unique + # stream ID - id_gen = self._create_id_generator() + async def _get_next_async2() -> None: + async with second_id_gen.get_next() as stream_id: + self.assertEqual(stream_id, 9) - self.assertEqual(id_gen.get_positions(), {"master": 7}) - self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + self.assertEqual( + second_id_gen.get_positions(), {"first": 3, "second": 7} + ) - # Persist two rows at once - ctx1 = id_gen.get_next() - ctx2 = id_gen.get_next() + self.get_success(_get_next_async2()) - s1 = self.get_success(ctx1.__aenter__()) - s2 = self.get_success(ctx2.__aenter__()) + self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9}) - self.assertEqual(s1, 8) - self.assertEqual(s2, 9) + # If the second ID gen gets told about the first, it correctly updates + second_id_gen.advance("first", 8) + self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9}) - self.assertEqual(id_gen.get_positions(), {"master": 7}) - self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + def test_multi_instance_empty_row(self) -> None: + """Test that reads and writes from multiple processes are handled + correctly, when one of the writers starts without any rows. + """ + # Insert some rows for two out of three of the ID gens. + self._insert_rows("first", 3) + self._insert_rows("second", 4) - # We finish persisting the second row before restart - self.get_success(ctx2.__aexit__(None, None, None)) + first_id_gen = self._create_id_generator( + "first", writers=["first", "second", "third"] + ) + second_id_gen = self._create_id_generator( + "second", writers=["first", "second", "third"] + ) + third_id_gen = self._create_id_generator( + "third", writers=["first", "second", "third"] + ) - # We simulate a restart of another worker by just creating a new ID gen. 
- id_gen_worker = self._create_id_generator("worker") + self.assertEqual( + first_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7} + ) + self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7) + self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7) + self.assertEqual(first_id_gen.get_current_token_for_writer("third"), 7) - # Restarted worker should not see the second persisted row - self.assertEqual(id_gen_worker.get_positions(), {"master": 7}) - self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7) + self.assertEqual( + second_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7} + ) + self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7) + self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7) + self.assertEqual(second_id_gen.get_current_token_for_writer("third"), 7) - # Now if we persist the first row then both instances should jump ahead - # correctly. - self.get_success(ctx1.__aexit__(None, None, None)) + # Try allocating a new ID gen and check that we only see position + # advanced after we leave the context manager. - self.assertEqual(id_gen.get_positions(), {"master": 9}) - id_gen_worker.advance("master", 9) - self.assertEqual(id_gen_worker.get_positions(), {"master": 9}) + async def _get_next_async() -> None: + async with third_id_gen.get_next() as stream_id: + self.assertEqual(stream_id, 8) + self.assertEqual( + third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7} + ) + self.assertEqual(third_id_gen.get_persisted_upto_position(), 7) -class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase): - if not USE_POSTGRES_FOR_TESTS: - skip = "Requires Postgres" + self.get_success(_get_next_async()) - def _insert_row_with_id(self, instance_name: str, stream_id: int) -> None: - """Insert one row as the given instance with given stream_id, updating - the postgres sequence position to match. 
- """ + self.assertEqual( + third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 8} + ) - def _insert(txn: LoggingTransaction) -> None: - txn.execute( - "INSERT INTO foobar (stream_id, instance_name) VALUES (?, ?)", - ( - stream_id, - instance_name, - ), - ) + def test_get_next_txn(self) -> None: + """Test that the `get_next_txn` function works correctly.""" - txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,)) + # Prefill table with 7 rows written by 'master' + self._insert_rows("master", 7) - txn.execute( - """ - INSERT INTO stream_positions VALUES ('test_stream', ?, ?) - ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ? - """, - (instance_name, stream_id, stream_id), - ) + id_gen = self._create_id_generator() - self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert)) + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + + # Try allocating a new ID gen and check that we only see position + # advanced after we leave the context manager. 
+ + def _get_next_txn(txn: LoggingTransaction) -> None: + stream_id = id_gen.get_next_txn(txn) + self.assertEqual(stream_id, 8) + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + + self.get_success(self.db_pool.runInteraction("test", _get_next_txn)) + + self.assertEqual(id_gen.get_positions(), {"master": 8}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) def test_get_persisted_upto_position(self) -> None: """Test that `get_persisted_upto_position` correctly tracks updates to @@ -418,9 +530,7 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase): id_gen = self._create_id_generator("first", writers=["first", "second"]) - # When the writer is created, it assumes its own position is the current head of - # the sequence - self.assertEqual(id_gen.get_positions(), {"first": 5, "second": 5}) + self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) self.assertEqual(id_gen.get_persisted_upto_position(), 5) @@ -437,118 +547,49 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase): # `persisted_upto_position` in this case, then it will be correct in the # other cases that are tested above (since they'll hit the same code). - def test_multi_instance(self) -> None: - """Test that reads and writes from multiple processes are handled - correctly. + def test_restart_during_out_of_order_persistence(self) -> None: + """Test that restarting a process while another process is writing out + of order updates are handled correctly. 
""" - self._insert_rows("first", 3) - first_id_gen = self._create_id_generator("first", writers=["first", "second"]) - - self._insert_rows("second", 4) - second_id_gen = self._create_id_generator("second", writers=["first", "second"]) - - self._replicate_all() - self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) - self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7) - self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7) - - self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) - self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7) - self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7) - - # Try allocating a new ID gen and check that we only see position - # advanced after we leave the context manager. - - async def _get_next_async() -> None: - async with first_id_gen.get_next() as stream_id: - self.assertEqual(stream_id, 8) - - self.assertEqual( - first_id_gen.get_positions(), {"first": 3, "second": 7} - ) - self.assertEqual( - second_id_gen.get_positions(), {"first": 3, "second": 7} - ) - self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) - - self.get_success(_get_next_async()) - - self.assertEqual(first_id_gen.get_positions(), {"first": 8, "second": 7}) - - # However the ID gen on the second instance won't have seen the update - self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) - - # ... 
but calling `get_next` on the second instance should give a unique - # stream ID - - async def _get_next_async2() -> None: - async with second_id_gen.get_next() as stream_id: - self.assertEqual(stream_id, 9) - - self.assertEqual( - second_id_gen.get_positions(), {"first": 3, "second": 7} - ) - - self.get_success(_get_next_async2()) - - self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9}) + # Prefill table with 7 rows written by 'master' + self._insert_rows("master", 7) - # If the second ID gen gets told about the first, it correctly updates - second_id_gen.advance("first", 8) - self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9}) + id_gen = self._create_id_generator() - def test_multi_instance_empty_row(self) -> None: - """Test that reads and writes from multiple processes are handled - correctly, when one of the writers starts without any rows. - """ - # Insert some rows for two out of three of the ID gens. - self._insert_rows("first", 3) - first_id_gen = self._create_id_generator( - "first", writers=["first", "second", "third"] - ) + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) - self._insert_rows("second", 4) - second_id_gen = self._create_id_generator( - "second", writers=["first", "second", "third"] - ) - third_id_gen = self._create_id_generator( - "third", writers=["first", "second", "third"] - ) + # Persist two rows at once + ctx1 = id_gen.get_next() + ctx2 = id_gen.get_next() - self._replicate_all() + s1 = self.get_success(ctx1.__aenter__()) + s2 = self.get_success(ctx2.__aenter__()) - self.assertEqual( - first_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7} - ) - self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7) - self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7) - self.assertEqual(first_id_gen.get_current_token_for_writer("third"), 7) + self.assertEqual(s1, 8) + 
self.assertEqual(s2, 9) - self.assertEqual( - second_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7} - ) - self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7) - self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7) - self.assertEqual(second_id_gen.get_current_token_for_writer("third"), 7) + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) - # Try allocating a new ID gen and check that we only see position - # advanced after we leave the context manager. + # We finish persisting the second row before restart + self.get_success(ctx2.__aexit__(None, None, None)) - async def _get_next_async() -> None: - async with third_id_gen.get_next() as stream_id: - self.assertEqual(stream_id, 8) + # We simulate a restart of another worker by just creating a new ID gen. + id_gen_worker = self._create_id_generator("worker") - self.assertEqual( - third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7} - ) - self.assertEqual(third_id_gen.get_persisted_upto_position(), 7) + # Restarted worker should not see the second persisted row + self.assertEqual(id_gen_worker.get_positions(), {"master": 7}) + self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7) - self.get_success(_get_next_async()) + # Now if we persist the first row then both instances should jump ahead + # correctly. 
+ self.get_success(ctx1.__aexit__(None, None, None)) - self.assertEqual( - third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 8} - ) + self.assertEqual(id_gen.get_positions(), {"master": 9}) + id_gen_worker.advance("master", 9) + self.assertEqual(id_gen_worker.get_positions(), {"master": 9}) def test_writer_config_change(self) -> None: """Test that changing the writer config correctly works.""" @@ -598,7 +639,7 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase): self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6) def test_sequence_consistency(self) -> None: - """Test that we correct the sequence if the table and sequence diverges.""" + """Test that we error out if the table and sequence diverges.""" # Prefill with some rows self._insert_row_with_id("master", 3) @@ -609,23 +650,16 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase): self.get_success(self.db_pool.runInteraction("_insert", _insert)) - # Creating the ID gen should now fix the inconsistency - id_gen = self._create_id_generator() - - async def _get_next_async() -> None: - async with id_gen.get_next() as stream_id: - self.assertEqual(stream_id, 27) - - self.get_success(_get_next_async()) + # Creating the ID gen should error + with self.assertRaises(IncorrectDatabaseSetup): + self._create_id_generator("first") def test_minimal_local_token(self) -> None: self._insert_rows("first", 3) - first_id_gen = self._create_id_generator("first", writers=["first", "second"]) - self._insert_rows("second", 4) - second_id_gen = self._create_id_generator("second", writers=["first", "second"]) - self._replicate_all() + first_id_gen = self._create_id_generator("first", writers=["first", "second"]) + second_id_gen = self._create_id_generator("second", writers=["first", "second"]) self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) self.assertEqual(first_id_gen.get_minimal_local_current_token(), 3) @@ -638,17 +672,15 @@ class 
WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase): token when there are no writes. """ self._insert_rows("first", 3) + self._insert_rows("second", 4) + first_id_gen = self._create_id_generator( "first", writers=["first", "second", "third"] ) - - self._insert_rows("second", 4) second_id_gen = self._create_id_generator( "second", writers=["first", "second", "third"] ) - self._replicate_all() - self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7) self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7) self.assertEqual(second_id_gen.get_current_token(), 7) @@ -687,13 +719,68 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase): self.assertEqual(second_id_gen.get_current_token(), 7) -class BackwardsMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase): +class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): """Tests MultiWriterIdGenerator that produce *negative* stream IDs.""" if not USE_POSTGRES_FOR_TESTS: skip = "Requires Postgres" - positive = False + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.db_pool: DatabasePool = self.store.db_pool + + self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) + + def _setup_db(self, txn: LoggingTransaction) -> None: + txn.execute("CREATE SEQUENCE foobar_seq") + txn.execute( + """ + CREATE TABLE foobar ( + stream_id BIGINT NOT NULL, + instance_name TEXT NOT NULL, + data TEXT + ); + """ + ) + + def _create_id_generator( + self, instance_name: str = "master", writers: Optional[List[str]] = None + ) -> MultiWriterIdGenerator: + def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator: + return MultiWriterIdGenerator( + conn, + self.db_pool, + notifier=self.hs.get_replication_notifier(), + stream_name="test_stream", + instance_name=instance_name, + tables=[("foobar", "instance_name", "stream_id")], + sequence_name="foobar_seq", + 
writers=writers or ["master"], + positive=False, + ) + + return self.get_success(self.db_pool.runWithConnection(_create)) + + def _insert_row(self, instance_name: str, stream_id: int) -> None: + """Insert one row as the given instance with given stream_id.""" + + def _insert(txn: LoggingTransaction) -> None: + txn.execute( + "INSERT INTO foobar VALUES (?, ?)", + ( + stream_id, + instance_name, + ), + ) + txn.execute( + """ + INSERT INTO stream_positions VALUES ('test_stream', ?, ?) + ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ? + """, + (instance_name, -stream_id, -stream_id), + ) + + self.get_success(self.db_pool.runInteraction("_insert_row", _insert)) def test_single_instance(self) -> None: """Test that reads and writes from a single process are handled @@ -739,7 +826,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase): async def _get_next_async() -> None: async with id_gen_1.get_next() as stream_id: self._insert_row("first", stream_id) - self._replicate("first") + id_gen_2.advance("first", stream_id) self.get_success(_get_next_async()) @@ -751,7 +838,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase): async def _get_next_async2() -> None: async with id_gen_2.get_next() as stream_id: self._insert_row("second", stream_id) - self._replicate("second") + id_gen_1.advance("second", stream_id) self.get_success(_get_next_async2()) @@ -761,26 +848,98 @@ class BackwardsMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase): self.assertEqual(id_gen_2.get_persisted_upto_position(), -2) -class MultiTableMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase): +class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): if not USE_POSTGRES_FOR_TESTS: skip = "Requires Postgres" - tables = ["foobar1", "foobar2"] + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.db_pool: DatabasePool = self.store.db_pool + + 
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) + + def _setup_db(self, txn: LoggingTransaction) -> None: + txn.execute("CREATE SEQUENCE foobar_seq") + txn.execute( + """ + CREATE TABLE foobar1 ( + stream_id BIGINT NOT NULL, + instance_name TEXT NOT NULL, + data TEXT + ); + """ + ) + + txn.execute( + """ + CREATE TABLE foobar2 ( + stream_id BIGINT NOT NULL, + instance_name TEXT NOT NULL, + data TEXT + ); + """ + ) + + def _create_id_generator( + self, instance_name: str = "master", writers: Optional[List[str]] = None + ) -> MultiWriterIdGenerator: + def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator: + return MultiWriterIdGenerator( + conn, + self.db_pool, + notifier=self.hs.get_replication_notifier(), + stream_name="test_stream", + instance_name=instance_name, + tables=[ + ("foobar1", "instance_name", "stream_id"), + ("foobar2", "instance_name", "stream_id"), + ], + sequence_name="foobar_seq", + writers=writers or ["master"], + ) + + return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) + + def _insert_rows( + self, + table: str, + instance_name: str, + number: int, + update_stream_table: bool = True, + ) -> None: + """Insert N rows as the given instance, inserting with stream IDs pulled + from the postgres sequence. + """ + + def _insert(txn: LoggingTransaction) -> None: + for _ in range(number): + txn.execute( + "INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,), + (instance_name,), + ) + if update_stream_table: + txn.execute( + """ + INSERT INTO stream_positions VALUES ('test_stream', ?, lastval()) + ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval() + """, + (instance_name,), + ) + + self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) def test_load_existing_stream(self) -> None: """Test creating ID gens with multiple tables that have rows from after the position in `stream_positions` table. 
""" - self._insert_rows("first", 3, table="foobar1") - first_id_gen = self._create_id_generator("first", writers=["first", "second"]) + self._insert_rows("foobar1", "first", 3) + self._insert_rows("foobar2", "second", 3) + self._insert_rows("foobar2", "second", 1, update_stream_table=False) - self._insert_rows("second", 3, table="foobar2") - self._insert_rows("second", 1, table="foobar2", update_stream_table=False) + first_id_gen = self._create_id_generator("first", writers=["first", "second"]) second_id_gen = self._create_id_generator("second", writers=["first", "second"]) - self._replicate_all() - - self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) + self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 6}) self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7) self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7) self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index 7b5774b8c1..3b5b3c404f 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 Awesome Technologies Innovationslabor GmbH # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 9df8ea4ee6..e57566854e 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2014-2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py
index 0b984c7ebc..f3a6759405 100644 --- a/tests/storage/test_receipts.py +++ b/tests/storage/test_receipts.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 91be2ccb3e..cb459d6b03 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2014-2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 14e3871dc1..b4277b5436 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2014-2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -43,6 +42,7 @@ class RegistrationStoreTestCase(HomeserverTestCase): self.assertEqual( UserInfo( + # TODO(paul): Surely this field should be 'user_id', not 'name' user_id=UserID.from_string(self.user_id), is_admin=False, is_guest=False, @@ -56,7 +56,6 @@ class RegistrationStoreTestCase(HomeserverTestCase): locked=False, is_shadow_banned=False, approved=True, - suspended=False, ), (self.get_success(self.store.get_user_by_id(self.user_id))), ) diff --git a/tests/storage/test_relations.py b/tests/storage/test_relations.py
index a7f7c840f3..5cc8dc263b 100644 --- a/tests/storage/test_relations.py +++ b/tests/storage/test_relations.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/test_rollback_worker.py b/tests/storage/test_rollback_worker.py
index 909aee043e..942692dd7b 100644 --- a/tests/storage/test_rollback_worker.py +++ b/tests/storage/test_rollback_worker.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 34d6fdb71e..7759e4cf91 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2014-2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index 340642b7e7..1788ca2ab9 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -71,16 +70,17 @@ class EventSearchInsertionTest(HomeserverTestCase): store.search_msgs([room_id], "hi bob", ["content.body"]) ) self.assertEqual(result.get("count"), 1) - self.assertIn("hi", result.get("highlights")) - self.assertIn("bob", result.get("highlights")) + if isinstance(store.database_engine, PostgresEngine): + self.assertIn("hi", result.get("highlights")) + self.assertIn("bob", result.get("highlights")) # Check that search works for an unrelated message result = self.get_success( store.search_msgs([room_id], "another", ["content.body"]) ) self.assertEqual(result.get("count"), 1) - - self.assertIn("another", result.get("highlights")) + if isinstance(store.database_engine, PostgresEngine): + self.assertIn("another", result.get("highlights")) # Check that search works for a search term that overlaps with the message # containing a null byte and an unrelated message. @@ -89,8 +89,8 @@ class EventSearchInsertionTest(HomeserverTestCase): result = self.get_success( store.search_msgs([room_id], "hi alice", ["content.body"]) ) - - self.assertIn("alice", result.get("highlights")) + if isinstance(store.database_engine, PostgresEngine): + self.assertIn("alice", result.get("highlights")) def test_non_string(self) -> None: """Test that non-string `value`s are not inserted into `event_search`. 
@@ -327,11 +327,9 @@ class MessageSearchTest(HomeserverTestCase): self.assertEqual( result["count"], 1 if expect_to_contain else 0, - ( - f"expected '{query}' to match '{self.PHRASE}'" - if expect_to_contain - else f"'{query}' unexpectedly matched '{self.PHRASE}'" - ), + f"expected '{query}' to match '{self.PHRASE}'" + if expect_to_contain + else f"'{query}' unexpectedly matched '{self.PHRASE}'", ) self.assertEqual( len(result["results"]), @@ -347,11 +345,9 @@ class MessageSearchTest(HomeserverTestCase): self.assertEqual( result["count"], 1 if expect_to_contain else 0, - ( - f"expected '{query}' to match '{self.PHRASE}'" - if expect_to_contain - else f"'{query}' unexpectedly matched '{self.PHRASE}'" - ), + f"expected '{query}' to match '{self.PHRASE}'" + if expect_to_contain + else f"'{query}' unexpectedly matched '{self.PHRASE}'", ) self.assertEqual( len(result["results"]), diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 418b556108..31fcc8f829 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py
@@ -1,8 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019 The Matrix.org Foundation C.I.C. -# Copyright 2014-2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -19,28 +17,20 @@ # [This file includes modifications made by New Vector Limited] # # -import logging from typing import List, Optional, Tuple, cast from twisted.test.proto_helpers import MemoryReactor -from synapse.api.constants import EventTypes, JoinRules, Membership -from synapse.api.room_versions import RoomVersions -from synapse.rest import admin +from synapse.api.constants import Membership from synapse.rest.admin import register_servlets_for_client_rest_resource -from synapse.rest.client import knock, login, room +from synapse.rest.client import login, room from synapse.server import HomeServer -from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary -from synapse.storage.roommember import MemberSummary from synapse.types import UserID, create_requester from synapse.util import Clock from tests import unittest from tests.server import TestHomeServer from tests.test_utils import event_injection -from tests.unittest import skip_unless - -logger = logging.getLogger(__name__) class RoomMemberStoreTestCase(unittest.HomeserverTestCase): @@ -248,397 +238,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): ) -class RoomSummaryTestCase(unittest.HomeserverTestCase): - """ - Test `/sync` room summary related logic like `get_room_summary(...)` and - `extract_heroes_from_room_summary(...)` - """ - - servlets = [ - admin.register_servlets, - knock.register_servlets, - login.register_servlets, - room.register_servlets, - ] - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.sliding_sync_handler = self.hs.get_sliding_sync_handler() - self.store = self.hs.get_datastores().main - - def 
_assert_member_summary( - self, - actual_member_summary: MemberSummary, - expected_member_list: List[str], - *, - expected_member_count: Optional[int] = None, - ) -> None: - """ - Assert that the `MemberSummary` object has the expected members. - """ - self.assertListEqual( - [ - user_id - for user_id, _membership_event_id in actual_member_summary.members - ], - expected_member_list, - ) - self.assertEqual( - actual_member_summary.count, - ( - expected_member_count - if expected_member_count is not None - else len(expected_member_list) - ), - ) - - def test_get_room_summary_membership(self) -> None: - """ - Test that `get_room_summary(...)` gets every kind of membership when there - aren't that many members in the room. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - user3_id = self.register_user("user3", "pass") - _user3_tok = self.login(user3_id, "pass") - user4_id = self.register_user("user4", "pass") - user4_tok = self.login(user4_id, "pass") - user5_id = self.register_user("user5", "pass") - user5_tok = self.login(user5_id, "pass") - - # Setup a room (user1 is the creator and is joined to the room) - room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - - # User2 is banned - self.helper.join(room_id, user2_id, tok=user2_tok) - self.helper.ban(room_id, src=user1_id, targ=user2_id, tok=user1_tok) - - # User3 is invited by user1 - self.helper.invite(room_id, targ=user3_id, tok=user1_tok) - - # User4 leaves - self.helper.join(room_id, user4_id, tok=user4_tok) - self.helper.leave(room_id, user4_id, tok=user4_tok) - - # User5 joins - self.helper.join(room_id, user5_id, tok=user5_tok) - - room_membership_summary = self.get_success(self.store.get_room_summary(room_id)) - empty_ms = MemberSummary([], 0) - - self._assert_member_summary( - room_membership_summary.get(Membership.JOIN, empty_ms), - [user1_id, user5_id], - 
) - self._assert_member_summary( - room_membership_summary.get(Membership.INVITE, empty_ms), [user3_id] - ) - self._assert_member_summary( - room_membership_summary.get(Membership.LEAVE, empty_ms), [user4_id] - ) - self._assert_member_summary( - room_membership_summary.get(Membership.BAN, empty_ms), [user2_id] - ) - self._assert_member_summary( - room_membership_summary.get(Membership.KNOCK, empty_ms), - [ - # No one knocked - ], - ) - - def test_get_room_summary_membership_order(self) -> None: - """ - Test that `get_room_summary(...)` stacks our limit of 6 in this order: joins -> - invites -> leave -> everything else (bans/knocks) - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - user3_id = self.register_user("user3", "pass") - _user3_tok = self.login(user3_id, "pass") - user4_id = self.register_user("user4", "pass") - user4_tok = self.login(user4_id, "pass") - user5_id = self.register_user("user5", "pass") - user5_tok = self.login(user5_id, "pass") - user6_id = self.register_user("user6", "pass") - user6_tok = self.login(user6_id, "pass") - user7_id = self.register_user("user7", "pass") - user7_tok = self.login(user7_id, "pass") - - # Setup the room (user1 is the creator and is joined to the room) - room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - - # We expect the order to be joins -> invites -> leave -> bans so setup the users - # *NOT* in that same order to make sure we're actually sorting them. 
- - # User2 is banned - self.helper.join(room_id, user2_id, tok=user2_tok) - self.helper.ban(room_id, src=user1_id, targ=user2_id, tok=user1_tok) - - # User3 is invited by user1 - self.helper.invite(room_id, targ=user3_id, tok=user1_tok) - - # User4 leaves - self.helper.join(room_id, user4_id, tok=user4_tok) - self.helper.leave(room_id, user4_id, tok=user4_tok) - - # User5, User6, User7 joins - self.helper.join(room_id, user5_id, tok=user5_tok) - self.helper.join(room_id, user6_id, tok=user6_tok) - self.helper.join(room_id, user7_id, tok=user7_tok) - - room_membership_summary = self.get_success(self.store.get_room_summary(room_id)) - empty_ms = MemberSummary([], 0) - - self._assert_member_summary( - room_membership_summary.get(Membership.JOIN, empty_ms), - [user1_id, user5_id, user6_id, user7_id], - ) - self._assert_member_summary( - room_membership_summary.get(Membership.INVITE, empty_ms), [user3_id] - ) - self._assert_member_summary( - room_membership_summary.get(Membership.LEAVE, empty_ms), [user4_id] - ) - self._assert_member_summary( - room_membership_summary.get(Membership.BAN, empty_ms), - [ - # The banned user is not in the summary because the summary can only fit - # 6 members and prefers everything else before bans - # - # user2_id - ], - # But we still see the count of banned users - expected_member_count=1, - ) - self._assert_member_summary( - room_membership_summary.get(Membership.KNOCK, empty_ms), - [ - # No one knocked - ], - ) - - def test_extract_heroes_from_room_summary_excludes_self(self) -> None: - """ - Test that `extract_heroes_from_room_summary(...)` does not include the user - itself. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - # Setup the room (user1 is the creator and is joined to the room) - room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - - # User2 joins - self.helper.join(room_id, user2_id, tok=user2_tok) - - room_membership_summary = self.get_success(self.store.get_room_summary(room_id)) - - # We first ask from the perspective of a random fake user - hero_user_ids = extract_heroes_from_room_summary( - room_membership_summary, me="@fakeuser" - ) - - # Make sure user1 is in the room (ensure our test setup is correct) - self.assertListEqual(hero_user_ids, [user1_id, user2_id]) - - # Now, we ask for the room summary from the perspective of user1 - hero_user_ids = extract_heroes_from_room_summary( - room_membership_summary, me=user1_id - ) - - # User1 should not be included in the list of heroes because they are the one - # asking - self.assertListEqual(hero_user_ids, [user2_id]) - - def test_extract_heroes_from_room_summary_first_five_joins(self) -> None: - """ - Test that `extract_heroes_from_room_summary(...)` returns the first 5 joins. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - user3_id = self.register_user("user3", "pass") - user3_tok = self.login(user3_id, "pass") - user4_id = self.register_user("user4", "pass") - user4_tok = self.login(user4_id, "pass") - user5_id = self.register_user("user5", "pass") - user5_tok = self.login(user5_id, "pass") - user6_id = self.register_user("user6", "pass") - user6_tok = self.login(user6_id, "pass") - user7_id = self.register_user("user7", "pass") - user7_tok = self.login(user7_id, "pass") - - # Setup the room (user1 is the creator and is joined to the room) - room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - - # User2 -> User7 joins - self.helper.join(room_id, user2_id, tok=user2_tok) - self.helper.join(room_id, user3_id, tok=user3_tok) - self.helper.join(room_id, user4_id, tok=user4_tok) - self.helper.join(room_id, user5_id, tok=user5_tok) - self.helper.join(room_id, user6_id, tok=user6_tok) - self.helper.join(room_id, user7_id, tok=user7_tok) - - room_membership_summary = self.get_success(self.store.get_room_summary(room_id)) - - hero_user_ids = extract_heroes_from_room_summary( - room_membership_summary, me="@fakuser" - ) - - # First 5 users to join the room - self.assertListEqual( - hero_user_ids, [user1_id, user2_id, user3_id, user4_id, user5_id] - ) - - def test_extract_heroes_from_room_summary_membership_order(self) -> None: - """ - Test that `extract_heroes_from_room_summary(...)` prefers joins/invites over - everything else. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - user3_id = self.register_user("user3", "pass") - _user3_tok = self.login(user3_id, "pass") - user4_id = self.register_user("user4", "pass") - user4_tok = self.login(user4_id, "pass") - user5_id = self.register_user("user5", "pass") - user5_tok = self.login(user5_id, "pass") - - # Setup the room (user1 is the creator and is joined to the room) - room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - - # We expect the order to be joins -> invites -> leave -> bans so setup the users - # *NOT* in that same order to make sure we're actually sorting them. - - # User2 is banned - self.helper.join(room_id, user2_id, tok=user2_tok) - self.helper.ban(room_id, src=user1_id, targ=user2_id, tok=user1_tok) - - # User3 is invited by user1 - self.helper.invite(room_id, targ=user3_id, tok=user1_tok) - - # User4 leaves - self.helper.join(room_id, user4_id, tok=user4_tok) - self.helper.leave(room_id, user4_id, tok=user4_tok) - - # User5 joins - self.helper.join(room_id, user5_id, tok=user5_tok) - - room_membership_summary = self.get_success(self.store.get_room_summary(room_id)) - - hero_user_ids = extract_heroes_from_room_summary( - room_membership_summary, me="@fakeuser" - ) - - # Prefer joins -> invites, over everything else - self.assertListEqual( - hero_user_ids, - [ - # The joins - user1_id, - user5_id, - # The invites - user3_id, - ], - ) - - @skip_unless( - False, - "Test is not possible because when everyone leaves the room, " - + "the server is `no_longer_in_room` and we don't have any `current_state_events` to query", - ) - def test_extract_heroes_from_room_summary_fallback_leave_ban(self) -> None: - """ - Test that `extract_heroes_from_room_summary(...)` falls back to leave/ban if - there aren't any joins/invites. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - user3_id = self.register_user("user3", "pass") - user3_tok = self.login(user3_id, "pass") - - # Setup the room (user1 is the creator and is joined to the room) - room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - - # User2 is banned - self.helper.join(room_id, user2_id, tok=user2_tok) - self.helper.ban(room_id, src=user1_id, targ=user2_id, tok=user1_tok) - - # User3 leaves - self.helper.join(room_id, user3_id, tok=user3_tok) - self.helper.leave(room_id, user3_id, tok=user3_tok) - - # User1 leaves (we're doing this last because they're the room creator) - self.helper.leave(room_id, user1_id, tok=user1_tok) - - room_membership_summary = self.get_success(self.store.get_room_summary(room_id)) - - hero_user_ids = extract_heroes_from_room_summary( - room_membership_summary, me="@fakeuser" - ) - - # Fallback to people who left -> banned - self.assertListEqual( - hero_user_ids, - [user3_id, user1_id, user3_id], - ) - - def test_extract_heroes_from_room_summary_excludes_knocks(self) -> None: - """ - People who knock on the room have (potentially) never been in the room before - and are total outsiders. Plus the spec doesn't mention them at all for heroes. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - # Setup the knock room (user1 is the creator and is joined to the room) - knock_room_id = self.helper.create_room_as( - user1_id, tok=user1_tok, room_version=RoomVersions.V7.identifier - ) - self.helper.send_state( - knock_room_id, - EventTypes.JoinRules, - {"join_rule": JoinRules.KNOCK}, - tok=user1_tok, - ) - - # User2 knocks on the room - knock_channel = self.make_request( - "POST", - "/_matrix/client/r0/knock/%s" % (knock_room_id,), - b"{}", - user2_tok, - ) - self.assertEqual(knock_channel.code, 200, knock_channel.result) - - room_membership_summary = self.get_success( - self.store.get_room_summary(knock_room_id) - ) - - hero_user_ids = extract_heroes_from_room_summary( - room_membership_summary, me="@fakeuser" - ) - - # user1 is the creator and is joined to the room (should show up as a hero) - # user2 is knocking on the room (should not show up as a hero) - self.assertListEqual( - hero_user_ids, - [user1_id], - ) - - class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 48f8d1c340..f42a74ac61 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index 9dea1af8ea..ab861ead93 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -19,36 +18,19 @@ # # -import logging -from typing import List, Tuple -from unittest.mock import AsyncMock, patch - -from immutabledict import immutabledict +from typing import List from twisted.test.proto_helpers import MemoryReactor -from synapse.api.constants import Direction, EventTypes, Membership, RelationTypes +from synapse.api.constants import Direction, EventTypes, RelationTypes from synapse.api.filtering import Filter -from synapse.crypto.event_signing import add_hashes_and_signatures -from synapse.events import FrozenEventV3 -from synapse.federation.federation_client import SendJoinResult from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer -from synapse.storage.databases.main.stream import CurrentStateDeltaMembership -from synapse.types import ( - JsonDict, - PersistedEventPosition, - RoomStreamToken, - UserID, - create_requester, -) +from synapse.types import JsonDict from synapse.util import Clock -from tests.test_utils.event_injection import create_event -from tests.unittest import FederatingHomeserverTestCase, HomeserverTestCase - -logger = logging.getLogger(__name__) +from tests.unittest import HomeserverTestCase class PaginationTestCase(HomeserverTestCase): @@ -285,1170 +267,3 @@ class PaginationTestCase(HomeserverTestCase): } chunk = self._filter_messages(filter) self.assertEqual(chunk, [self.event_id_1, self.event_id_2, self.event_id_none]) - - -class GetLastEventInRoomBeforeStreamOrderingTestCase(HomeserverTestCase): - """ - Test `get_last_event_pos_in_room_before_stream_ordering(...)` - """ - - servlets = [ - admin.register_servlets, - room.register_servlets, - login.register_servlets, - ] - - def prepare(self, reactor: 
MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastores().main - self.event_sources = hs.get_event_sources() - - def _update_persisted_instance_name_for_event( - self, event_id: str, instance_name: str - ) -> None: - """ - Update the `instance_name` that persisted the the event in the database. - """ - return self.get_success( - self.store.db_pool.simple_update_one( - "events", - keyvalues={"event_id": event_id}, - updatevalues={"instance_name": instance_name}, - ) - ) - - def _send_event_on_instance( - self, instance_name: str, room_id: str, access_token: str - ) -> Tuple[JsonDict, PersistedEventPosition]: - """ - Send an event in a room and mimic that it was persisted by a specific - instance/worker. - """ - event_response = self.helper.send( - room_id, f"{instance_name} message", tok=access_token - ) - - self._update_persisted_instance_name_for_event( - event_response["event_id"], instance_name - ) - - event_pos = self.get_success( - self.store.get_position_for_event(event_response["event_id"]) - ) - - return event_response, event_pos - - def test_before_room_created(self) -> None: - """ - Test that no event is returned if we are using a token before the room was even created - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - before_room_token = self.event_sources.get_current_token() - - room_id = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True) - - last_event_result = self.get_success( - self.store.get_last_event_pos_in_room_before_stream_ordering( - room_id=room_id, - end_token=before_room_token.room_key, - ) - ) - - self.assertIsNone(last_event_result) - - def test_after_room_created(self) -> None: - """ - Test that an event is returned if we are using a token after the room was created - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - room_id = self.helper.create_room_as(user1_id, tok=user1_tok, 
is_public=True) - - after_room_token = self.event_sources.get_current_token() - - last_event_result = self.get_success( - self.store.get_last_event_pos_in_room_before_stream_ordering( - room_id=room_id, - end_token=after_room_token.room_key, - ) - ) - assert last_event_result is not None - last_event_id, _ = last_event_result - - self.assertIsNotNone(last_event_id) - - def test_activity_in_other_rooms(self) -> None: - """ - Test to make sure that the last event in the room is returned even if the - `stream_ordering` has advanced from activity in other rooms. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True) - event_response = self.helper.send(room_id1, "target!", tok=user1_tok) - # Create another room to advance the stream_ordering - self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True) - - after_room_token = self.event_sources.get_current_token() - - last_event_result = self.get_success( - self.store.get_last_event_pos_in_room_before_stream_ordering( - room_id=room_id1, - end_token=after_room_token.room_key, - ) - ) - assert last_event_result is not None - last_event_id, _ = last_event_result - - # Make sure it's the event we expect (which also means we know it's from the - # correct room) - self.assertEqual(last_event_id, event_response["event_id"]) - - def test_activity_after_token_has_no_effect(self) -> None: - """ - Test to make sure we return the last event before the token even if there is - activity after it. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True) - event_response = self.helper.send(room_id1, "target!", tok=user1_tok) - - after_room_token = self.event_sources.get_current_token() - - # Send some events after the token - self.helper.send(room_id1, "after1", tok=user1_tok) - self.helper.send(room_id1, "after2", tok=user1_tok) - - last_event_result = self.get_success( - self.store.get_last_event_pos_in_room_before_stream_ordering( - room_id=room_id1, - end_token=after_room_token.room_key, - ) - ) - assert last_event_result is not None - last_event_id, _ = last_event_result - - # Make sure it's the last event before the token - self.assertEqual(last_event_id, event_response["event_id"]) - - def test_last_event_within_sharded_token(self) -> None: - """ - Test to make sure we can find the last event that that is *within* the sharded - token (a token that has an `instance_map` and looks like - `m{min_pos}~{writer1}.{pos1}~{writer2}.{pos2}`). We are specifically testing - that we can find an event within the tokens minimum and instance - `stream_ordering`. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True) - event_response1, event_pos1 = self._send_event_on_instance( - "worker1", room_id1, user1_tok - ) - event_response2, event_pos2 = self._send_event_on_instance( - "worker1", room_id1, user1_tok - ) - event_response3, event_pos3 = self._send_event_on_instance( - "worker1", room_id1, user1_tok - ) - - # Create another room to advance the `stream_ordering` on the same worker - # so we can sandwich event3 in the middle of the token - room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True) - event_response4, event_pos4 = self._send_event_on_instance( - "worker1", room_id2, user1_tok - ) - - # Assemble a token that encompasses event1 -> event4 on worker1 - end_token = RoomStreamToken( - stream=event_pos2.stream, - instance_map=immutabledict({"worker1": event_pos4.stream}), - ) - - # Send some events after the token - self.helper.send(room_id1, "after1", tok=user1_tok) - self.helper.send(room_id1, "after2", tok=user1_tok) - - last_event_result = self.get_success( - self.store.get_last_event_pos_in_room_before_stream_ordering( - room_id=room_id1, - end_token=end_token, - ) - ) - assert last_event_result is not None - last_event_id, _ = last_event_result - - # Should find closest event before the token in room1 - self.assertEqual( - last_event_id, - event_response3["event_id"], - f"We expected {event_response3['event_id']} but saw {last_event_id} which corresponds to " - + str( - { - "event1": event_response1["event_id"], - "event2": event_response2["event_id"], - "event3": event_response3["event_id"], - } - ), - ) - - def test_last_event_before_sharded_token(self) -> None: - """ - Test to make sure we can find the last event that is *before* the sharded token - (a token that has an `instance_map` and looks like - `m{min_pos}~{writer1}.{pos1}~{writer2}.{pos2}`). 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True) - event_response1, event_pos1 = self._send_event_on_instance( - "worker1", room_id1, user1_tok - ) - event_response2, event_pos2 = self._send_event_on_instance( - "worker1", room_id1, user1_tok - ) - - # Create another room to advance the `stream_ordering` on the same worker - room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True) - event_response3, event_pos3 = self._send_event_on_instance( - "worker1", room_id2, user1_tok - ) - event_response4, event_pos4 = self._send_event_on_instance( - "worker1", room_id2, user1_tok - ) - - # Assemble a token that encompasses event3 -> event4 on worker1 - end_token = RoomStreamToken( - stream=event_pos3.stream, - instance_map=immutabledict({"worker1": event_pos4.stream}), - ) - - # Send some events after the token - self.helper.send(room_id1, "after1", tok=user1_tok) - self.helper.send(room_id1, "after2", tok=user1_tok) - - last_event_result = self.get_success( - self.store.get_last_event_pos_in_room_before_stream_ordering( - room_id=room_id1, - end_token=end_token, - ) - ) - assert last_event_result is not None - last_event_id, _ = last_event_result - - # Should find closest event before the token in room1 - self.assertEqual( - last_event_id, - event_response2["event_id"], - f"We expected {event_response2['event_id']} but saw {last_event_id} which corresponds to " - + str( - { - "event1": event_response1["event_id"], - "event2": event_response2["event_id"], - } - ), - ) - - def test_restrict_event_types(self) -> None: - """ - Test that we only consider given `event_types` when finding the last event - before a token. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True) - event_response = self.helper.send_event( - room_id1, - type="org.matrix.special_message", - content={"body": "before1, target!"}, - tok=user1_tok, - ) - self.helper.send(room_id1, "before2", tok=user1_tok) - - after_room_token = self.event_sources.get_current_token() - - # Send some events after the token - self.helper.send_event( - room_id1, - type="org.matrix.special_message", - content={"body": "after1"}, - tok=user1_tok, - ) - self.helper.send(room_id1, "after2", tok=user1_tok) - - last_event_result = self.get_success( - self.store.get_last_event_pos_in_room_before_stream_ordering( - room_id=room_id1, - end_token=after_room_token.room_key, - event_types=["org.matrix.special_message"], - ) - ) - assert last_event_result is not None - last_event_id, _ = last_event_result - - # Make sure it's the last event before the token - self.assertEqual(last_event_id, event_response["event_id"]) - - -class GetCurrentStateDeltaMembershipChangesForUserTestCase(HomeserverTestCase): - """ - Test `get_current_state_delta_membership_changes_for_user(...)` - """ - - servlets = [ - admin.register_servlets, - room.register_servlets, - login.register_servlets, - ] - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastores().main - self.event_sources = hs.get_event_sources() - self.state_handler = self.hs.get_state_handler() - persistence = hs.get_storage_controllers().persistence - assert persistence is not None - self.persistence = persistence - - def test_returns_membership_events(self) -> None: - """ - A basic test that a membership event in the token range is returned for the user. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - before_room1_token = self.event_sources.get_current_token() - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - join_response = self.helper.join(room_id1, user1_id, tok=user1_tok) - join_pos = self.get_success( - self.store.get_position_for_event(join_response["event_id"]) - ) - - after_room1_token = self.event_sources.get_current_token() - - membership_changes = self.get_success( - self.store.get_current_state_delta_membership_changes_for_user( - user1_id, - from_key=before_room1_token.room_key, - to_key=after_room1_token.room_key, - ) - ) - - # Let the whole diff show on failure - self.maxDiff = None - self.assertEqual( - membership_changes, - [ - CurrentStateDeltaMembership( - room_id=room_id1, - event_id=join_response["event_id"], - event_pos=join_pos, - membership="join", - sender=user1_id, - prev_event_id=None, - prev_event_pos=None, - prev_membership=None, - prev_sender=None, - ) - ], - ) - - def test_server_left_room_after_us(self) -> None: - """ - Test that when probing over part of the DAG where the server left the room *after - us*, we still see the join and leave changes. - - This is to make sure we play nicely with this behavior: When the server leaves a - room, it will insert new rows with `event_id = null` into the - `current_state_delta_stream` table for all current state. 
- """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - before_room1_token = self.event_sources.get_current_token() - - room_id1 = self.helper.create_room_as( - user2_id, - tok=user2_tok, - extra_content={ - "power_level_content_override": { - "users": { - user2_id: 100, - # Allow user1 to send state in the room - user1_id: 100, - } - } - }, - ) - join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok) - join_pos1 = self.get_success( - self.store.get_position_for_event(join_response1["event_id"]) - ) - # Make sure that random other non-member state that happens to have a `state_key` - # matching the user ID doesn't mess with things. - self.helper.send_state( - room_id1, - event_type="foobarbazdummy", - state_key=user1_id, - body={"foo": "bar"}, - tok=user1_tok, - ) - # User1 should leave the room first - leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok) - leave_pos1 = self.get_success( - self.store.get_position_for_event(leave_response1["event_id"]) - ) - - # User2 should also leave the room (everyone has left the room which means the - # server is no longer in the room). - self.helper.leave(room_id1, user2_id, tok=user2_tok) - - after_room1_token = self.event_sources.get_current_token() - - # Get the membership changes for the user. - # - # At this point, the `current_state_delta_stream` table should look like the - # following. When the server leaves a room, it will insert new rows with - # `event_id = null` for all current state. 
- # - # | stream_id | room_id | type | state_key | event_id | prev_event_id | - # |-----------|----------|-----------------------------|----------------|----------|---------------| - # | 2 | !x:test | 'm.room.create' | '' | $xxx | None | - # | 3 | !x:test | 'm.room.member' | '@user2:test' | $aaa | None | - # | 4 | !x:test | 'm.room.history_visibility' | '' | $xxx | None | - # | 4 | !x:test | 'm.room.join_rules' | '' | $xxx | None | - # | 4 | !x:test | 'm.room.power_levels' | '' | $xxx | None | - # | 7 | !x:test | 'm.room.member' | '@user1:test' | $ooo | None | - # | 8 | !x:test | 'foobarbazdummy' | '@user1:test' | $xxx | None | - # | 9 | !x:test | 'm.room.member' | '@user1:test' | $ppp | $ooo | - # | 10 | !x:test | 'foobarbazdummy' | '@user1:test' | None | $xxx | - # | 10 | !x:test | 'm.room.create' | '' | None | $xxx | - # | 10 | !x:test | 'm.room.history_visibility' | '' | None | $xxx | - # | 10 | !x:test | 'm.room.join_rules' | '' | None | $xxx | - # | 10 | !x:test | 'm.room.member' | '@user1:test' | None | $ppp | - # | 10 | !x:test | 'm.room.member' | '@user2:test' | None | $aaa | - # | 10 | !x:test | 'm.room.power_levels' | | None | $xxx | - membership_changes = self.get_success( - self.store.get_current_state_delta_membership_changes_for_user( - user1_id, - from_key=before_room1_token.room_key, - to_key=after_room1_token.room_key, - ) - ) - - # Let the whole diff show on failure - self.maxDiff = None - self.assertEqual( - membership_changes, - [ - CurrentStateDeltaMembership( - room_id=room_id1, - event_id=join_response1["event_id"], - event_pos=join_pos1, - membership="join", - sender=user1_id, - prev_event_id=None, - prev_event_pos=None, - prev_membership=None, - prev_sender=None, - ), - CurrentStateDeltaMembership( - room_id=room_id1, - event_id=leave_response1["event_id"], - event_pos=leave_pos1, - membership="leave", - sender=user1_id, - prev_event_id=join_response1["event_id"], - prev_event_pos=join_pos1, - prev_membership="join", - 
prev_sender=user1_id, - ), - ], - ) - - def test_server_left_room_after_us_later(self) -> None: - """ - Test when the user leaves the room, then sometime later, everyone else leaves - the room, causing the server to leave the room, we shouldn't see any membership - changes. - - This is to make sure we play nicely with this behavior: When the server leaves a - room, it will insert new rows with `event_id = null` into the - `current_state_delta_stream` table for all current state. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id1, user1_id, tok=user1_tok) - # User1 should leave the room first - self.helper.leave(room_id1, user1_id, tok=user1_tok) - - after_user1_leave_token = self.event_sources.get_current_token() - - # User2 should also leave the room (everyone has left the room which means the - # server is no longer in the room). - self.helper.leave(room_id1, user2_id, tok=user2_tok) - - after_server_leave_token = self.event_sources.get_current_token() - - # Join another room as user1 just to advance the stream_ordering and bust - # `_membership_stream_cache` - room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id2, user1_id, tok=user1_tok) - - # Get the membership changes for the user. - # - # At this point, the `current_state_delta_stream` table should look like the - # following. When the server leaves a room, it will insert new rows with - # `event_id = null` for all current state. - # - # TODO: Add DB rows to better see what's going on. 
- membership_changes = self.get_success( - self.store.get_current_state_delta_membership_changes_for_user( - user1_id, - from_key=after_user1_leave_token.room_key, - to_key=after_server_leave_token.room_key, - ) - ) - - # Let the whole diff show on failure - self.maxDiff = None - self.assertEqual( - membership_changes, - [], - ) - - def test_we_cause_server_left_room(self) -> None: - """ - Test that when probing over part of the DAG where the user leaves the room - causing the server to leave the room (because we were the last local user in the - room), we still see the join and leave changes. - - This is to make sure we play nicely with this behavior: When the server leaves a - room, it will insert new rows with `event_id = null` into the - `current_state_delta_stream` table for all current state. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - before_room1_token = self.event_sources.get_current_token() - - room_id1 = self.helper.create_room_as( - user2_id, - tok=user2_tok, - extra_content={ - "power_level_content_override": { - "users": { - user2_id: 100, - # Allow user1 to send state in the room - user1_id: 100, - } - } - }, - ) - join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok) - join_pos1 = self.get_success( - self.store.get_position_for_event(join_response1["event_id"]) - ) - # Make sure that random other non-member state that happens to have a `state_key` - # matching the user ID doesn't mess with things. - self.helper.send_state( - room_id1, - event_type="foobarbazdummy", - state_key=user1_id, - body={"foo": "bar"}, - tok=user1_tok, - ) - - # User2 should leave the room first. - self.helper.leave(room_id1, user2_id, tok=user2_tok) - - # User1 (the person we're testing with) should also leave the room (everyone has - # left the room which means the server is no longer in the room). 
- leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok) - leave_pos1 = self.get_success( - self.store.get_position_for_event(leave_response1["event_id"]) - ) - - after_room1_token = self.event_sources.get_current_token() - - # Get the membership changes for the user. - # - # At this point, the `current_state_delta_stream` table should look like the - # following. When the server leaves a room, it will insert new rows with - # `event_id = null` for all current state. - # - # | stream_id | room_id | type | state_key | event_id | prev_event_id | - # |-----------|-----------|-----------------------------|---------------|----------|---------------| - # | 2 | '!x:test' | 'm.room.create' | '' | '$xxx' | None | - # | 3 | '!x:test' | 'm.room.member' | '@user2:test' | '$aaa' | None | - # | 4 | '!x:test' | 'm.room.history_visibility' | '' | '$xxx' | None | - # | 4 | '!x:test' | 'm.room.join_rules' | '' | '$xxx' | None | - # | 4 | '!x:test' | 'm.room.power_levels' | '' | '$xxx' | None | - # | 7 | '!x:test' | 'm.room.member' | '@user1:test' | '$ooo' | None | - # | 8 | '!x:test' | 'foobarbazdummy' | '@user1:test' | '$xxx' | None | - # | 9 | '!x:test' | 'm.room.member' | '@user2:test' | '$bbb' | '$aaa' | - # | 10 | '!x:test' | 'foobarbazdummy' | '@user1:test' | None | '$xxx' | - # | 10 | '!x:test' | 'm.room.create' | '' | None | '$xxx' | - # | 10 | '!x:test' | 'm.room.history_visibility' | '' | None | '$xxx' | - # | 10 | '!x:test' | 'm.room.join_rules' | '' | None | '$xxx' | - # | 10 | '!x:test' | 'm.room.member' | '@user1:test' | None | '$ooo' | - # | 10 | '!x:test' | 'm.room.member' | '@user2:test' | None | '$bbb' | - # | 10 | '!x:test' | 'm.room.power_levels' | '' | None | '$xxx' | - membership_changes = self.get_success( - self.store.get_current_state_delta_membership_changes_for_user( - user1_id, - from_key=before_room1_token.room_key, - to_key=after_room1_token.room_key, - ) - ) - - # Let the whole diff show on failure - self.maxDiff = None - 
self.assertEqual( - membership_changes, - [ - CurrentStateDeltaMembership( - room_id=room_id1, - event_id=join_response1["event_id"], - event_pos=join_pos1, - membership="join", - sender=user1_id, - prev_event_id=None, - prev_event_pos=None, - prev_membership=None, - prev_sender=None, - ), - CurrentStateDeltaMembership( - room_id=room_id1, - event_id=None, # leave_response1["event_id"], - event_pos=leave_pos1, - membership="leave", - sender=None, # user1_id, - prev_event_id=join_response1["event_id"], - prev_event_pos=join_pos1, - prev_membership="join", - prev_sender=user1_id, - ), - ], - ) - - def test_different_user_membership_persisted_in_same_batch(self) -> None: - """ - Test batch of membership events from different users being processed at once. - This will result in all of the memberships being stored in the - `current_state_delta_stream` table with the same `stream_ordering` even though - the individual events have different `stream_ordering`s. - """ - user1_id = self.register_user("user1", "pass") - _user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - user3_id = self.register_user("user3", "pass") - _user3_tok = self.login(user3_id, "pass") - user4_id = self.register_user("user4", "pass") - _user4_tok = self.login(user4_id, "pass") - - before_room1_token = self.event_sources.get_current_token() - - # User2 is just the designated person to create the room (we do this across the - # tests to be consistent) - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - - # Persist the user1, user3, and user4 join events in the same batch so they all - # end up in the `current_state_delta_stream` table with the same - # stream_ordering. 
- join_event3, join_event_context3 = self.get_success( - create_event( - self.hs, - sender=user3_id, - type=EventTypes.Member, - state_key=user3_id, - content={"membership": "join"}, - room_id=room_id1, - ) - ) - # We want to put user1 in the middle of the batch. This way, regardless of the - # implementation that inserts rows into current_state_delta_stream` (whether it - # be minimum/maximum of stream position of the batch), we will still catch bugs. - join_event1, join_event_context1 = self.get_success( - create_event( - self.hs, - sender=user1_id, - type=EventTypes.Member, - state_key=user1_id, - content={"membership": "join"}, - room_id=room_id1, - ) - ) - join_event4, join_event_context4 = self.get_success( - create_event( - self.hs, - sender=user4_id, - type=EventTypes.Member, - state_key=user4_id, - content={"membership": "join"}, - room_id=room_id1, - ) - ) - self.get_success( - self.persistence.persist_events( - [ - (join_event3, join_event_context3), - (join_event1, join_event_context1), - (join_event4, join_event_context4), - ] - ) - ) - - after_room1_token = self.event_sources.get_current_token() - - # Get the membership changes for the user. 
- # - # At this point, the `current_state_delta_stream` table should look like (notice - # those three memberships at the end with `stream_id=7` because we persisted - # them in the same batch): - # - # | stream_id | room_id | type | state_key | event_id | prev_event_id | - # |-----------|-----------|----------------------------|------------------|----------|---------------| - # | 2 | '!x:test' | 'm.room.create' | '' | '$xxx' | None | - # | 3 | '!x:test' | 'm.room.member' | '@user2:test' | '$xxx' | None | - # | 4 | '!x:test' | 'm.room.history_visibility'| '' | '$xxx' | None | - # | 4 | '!x:test' | 'm.room.join_rules' | '' | '$xxx' | None | - # | 4 | '!x:test' | 'm.room.power_levels' | '' | '$xxx' | None | - # | 7 | '!x:test' | 'm.room.member' | '@user3:test' | '$xxx' | None | - # | 7 | '!x:test' | 'm.room.member' | '@user1:test' | '$xxx' | None | - # | 7 | '!x:test' | 'm.room.member' | '@user4:test' | '$xxx' | None | - membership_changes = self.get_success( - self.store.get_current_state_delta_membership_changes_for_user( - user1_id, - from_key=before_room1_token.room_key, - to_key=after_room1_token.room_key, - ) - ) - - join_pos3 = self.get_success( - self.store.get_position_for_event(join_event3.event_id) - ) - - # Let the whole diff show on failure - self.maxDiff = None - self.assertEqual( - membership_changes, - [ - CurrentStateDeltaMembership( - room_id=room_id1, - event_id=join_event1.event_id, - # Ideally, this would be `join_pos1` (to match the `event_id`) but - # when events are persisted in a batch, they are all stored in the - # `current_state_delta_stream` table with the minimum - # `stream_ordering` from the batch. 
- event_pos=join_pos3, - membership="join", - sender=user1_id, - prev_event_id=None, - prev_event_pos=None, - prev_membership=None, - prev_sender=None, - ), - ], - ) - - def test_state_reset(self) -> None: - """ - Test a state reset scenario where the user gets removed from the room (when - there is no corresponding leave event) - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok) - join_pos1 = self.get_success( - self.store.get_position_for_event(join_response1["event_id"]) - ) - - before_reset_token = self.event_sources.get_current_token() - - # Send another state event to make a position for the state reset to happen at - dummy_state_response = self.helper.send_state( - room_id1, - event_type="foobarbaz", - state_key="", - body={"foo": "bar"}, - tok=user2_tok, - ) - dummy_state_pos = self.get_success( - self.store.get_position_for_event(dummy_state_response["event_id"]) - ) - - # Mock a state reset removing the membership for user1 in the current state - self.get_success( - self.store.db_pool.simple_delete( - table="current_state_events", - keyvalues={ - "room_id": room_id1, - "type": EventTypes.Member, - "state_key": user1_id, - }, - desc="state reset user in current_state_delta_stream", - ) - ) - self.get_success( - self.store.db_pool.simple_insert( - table="current_state_delta_stream", - values={ - "stream_id": dummy_state_pos.stream, - "room_id": room_id1, - "type": EventTypes.Member, - "state_key": user1_id, - "event_id": None, - "prev_event_id": join_response1["event_id"], - "instance_name": dummy_state_pos.instance_name, - }, - desc="state reset user in current_state_delta_stream", - ) - ) - - # Manually bust the cache since we we're just manually messing with the database - # and not 
causing an actual state reset. - self.store._membership_stream_cache.entity_has_changed( - user1_id, dummy_state_pos.stream - ) - - after_reset_token = self.event_sources.get_current_token() - - membership_changes = self.get_success( - self.store.get_current_state_delta_membership_changes_for_user( - user1_id, - from_key=before_reset_token.room_key, - to_key=after_reset_token.room_key, - ) - ) - - # Let the whole diff show on failure - self.maxDiff = None - self.assertEqual( - membership_changes, - [ - CurrentStateDeltaMembership( - room_id=room_id1, - event_id=None, - event_pos=dummy_state_pos, - membership="leave", - sender=None, # user1_id, - prev_event_id=join_response1["event_id"], - prev_event_pos=join_pos1, - prev_membership="join", - prev_sender=user1_id, - ), - ], - ) - - def test_excluded_room_ids(self) -> None: - """ - Test that the `excluded_room_ids` option excludes changes from the specified rooms. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - before_room1_token = self.event_sources.get_current_token() - - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok) - join_pos1 = self.get_success( - self.store.get_position_for_event(join_response1["event_id"]) - ) - - room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok) - join_response2 = self.helper.join(room_id2, user1_id, tok=user1_tok) - join_pos2 = self.get_success( - self.store.get_position_for_event(join_response2["event_id"]) - ) - - after_room1_token = self.event_sources.get_current_token() - - # First test the the room is returned without the `excluded_room_ids` option - membership_changes = self.get_success( - self.store.get_current_state_delta_membership_changes_for_user( - user1_id, - from_key=before_room1_token.room_key, - to_key=after_room1_token.room_key, - ) 
- ) - - # Let the whole diff show on failure - self.maxDiff = None - self.assertEqual( - membership_changes, - [ - CurrentStateDeltaMembership( - room_id=room_id1, - event_id=join_response1["event_id"], - event_pos=join_pos1, - membership="join", - sender=user1_id, - prev_event_id=None, - prev_event_pos=None, - prev_membership=None, - prev_sender=None, - ), - CurrentStateDeltaMembership( - room_id=room_id2, - event_id=join_response2["event_id"], - event_pos=join_pos2, - membership="join", - sender=user1_id, - prev_event_id=None, - prev_event_pos=None, - prev_membership=None, - prev_sender=None, - ), - ], - ) - - # The test that `excluded_room_ids` excludes room2 as expected - membership_changes = self.get_success( - self.store.get_current_state_delta_membership_changes_for_user( - user1_id, - from_key=before_room1_token.room_key, - to_key=after_room1_token.room_key, - excluded_room_ids=[room_id2], - ) - ) - - # Let the whole diff show on failure - self.maxDiff = None - self.assertEqual( - membership_changes, - [ - CurrentStateDeltaMembership( - room_id=room_id1, - event_id=join_response1["event_id"], - event_pos=join_pos1, - membership="join", - sender=user1_id, - prev_event_id=None, - prev_event_pos=None, - prev_membership=None, - prev_sender=None, - ) - ], - ) - - -class GetCurrentStateDeltaMembershipChangesForUserFederationTestCase( - FederatingHomeserverTestCase -): - """ - Test `get_current_state_delta_membership_changes_for_user(...)` when joining remote federated rooms. 
- """ - - servlets = [ - admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - ] - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.sliding_sync_handler = self.hs.get_sliding_sync_handler() - self.store = self.hs.get_datastores().main - self.event_sources = hs.get_event_sources() - self.room_member_handler = hs.get_room_member_handler() - - def test_remote_join(self) -> None: - """ - Test remote join where the first rows in `current_state_delta_stream` will just - be the state when you joined the remote room. - """ - user1_id = self.register_user("user1", "pass") - _user1_tok = self.login(user1_id, "pass") - - before_join_token = self.event_sources.get_current_token() - - intially_unjoined_room_id = f"!example:{self.OTHER_SERVER_NAME}" - - # Remotely join a room on another homeserver. - # - # To do this we have to mock the responses from the remote homeserver. We also - # patch out a bunch of event checks on our end. 
- create_event_source = { - "auth_events": [], - "content": { - "creator": f"@creator:{self.OTHER_SERVER_NAME}", - "room_version": self.hs.config.server.default_room_version.identifier, - }, - "depth": 0, - "origin_server_ts": 0, - "prev_events": [], - "room_id": intially_unjoined_room_id, - "sender": f"@creator:{self.OTHER_SERVER_NAME}", - "state_key": "", - "type": EventTypes.Create, - } - self.add_hashes_and_signatures_from_other_server( - create_event_source, - self.hs.config.server.default_room_version, - ) - create_event = FrozenEventV3( - create_event_source, - self.hs.config.server.default_room_version, - {}, - None, - ) - creator_join_event_source = { - "auth_events": [create_event.event_id], - "content": { - "membership": "join", - }, - "depth": 1, - "origin_server_ts": 1, - "prev_events": [], - "room_id": intially_unjoined_room_id, - "sender": f"@creator:{self.OTHER_SERVER_NAME}", - "state_key": f"@creator:{self.OTHER_SERVER_NAME}", - "type": EventTypes.Member, - } - self.add_hashes_and_signatures_from_other_server( - creator_join_event_source, - self.hs.config.server.default_room_version, - ) - creator_join_event = FrozenEventV3( - creator_join_event_source, - self.hs.config.server.default_room_version, - {}, - None, - ) - - # Our local user is going to remote join the room - join_event_source = { - "auth_events": [create_event.event_id], - "content": {"membership": "join"}, - "depth": 1, - "origin_server_ts": 100, - "prev_events": [creator_join_event.event_id], - "sender": user1_id, - "state_key": user1_id, - "room_id": intially_unjoined_room_id, - "type": EventTypes.Member, - } - add_hashes_and_signatures( - self.hs.config.server.default_room_version, - join_event_source, - self.hs.hostname, - self.hs.signing_key, - ) - join_event = FrozenEventV3( - join_event_source, - self.hs.config.server.default_room_version, - {}, - None, - ) - - mock_make_membership_event = AsyncMock( - return_value=( - self.OTHER_SERVER_NAME, - join_event, - 
self.hs.config.server.default_room_version, - ) - ) - mock_send_join = AsyncMock( - return_value=SendJoinResult( - join_event, - self.OTHER_SERVER_NAME, - state=[create_event, creator_join_event], - auth_chain=[create_event, creator_join_event], - partial_state=False, - servers_in_room=frozenset(), - ) - ) - - with patch.object( - self.room_member_handler.federation_handler.federation_client, - "make_membership_event", - mock_make_membership_event, - ), patch.object( - self.room_member_handler.federation_handler.federation_client, - "send_join", - mock_send_join, - ), patch( - "synapse.event_auth._is_membership_change_allowed", - return_value=None, - ), patch( - "synapse.handlers.federation_event.check_state_dependent_auth_rules", - return_value=None, - ): - self.get_success( - self.room_member_handler.update_membership( - requester=create_requester(user1_id), - target=UserID.from_string(user1_id), - room_id=intially_unjoined_room_id, - action=Membership.JOIN, - remote_room_hosts=[self.OTHER_SERVER_NAME], - ) - ) - - after_join_token = self.event_sources.get_current_token() - - # Get the membership changes for the user. - # - # At this point, the `current_state_delta_stream` table should look like the - # following. 
Notice that all of the events are at the same `stream_id` because - # the current state starts out where we remotely joined: - # - # | stream_id | room_id | type | state_key | event_id | prev_event_id | - # |-----------|------------------------------|-----------------|------------------------------|----------|---------------| - # | 2 | '!example:other.example.com' | 'm.room.member' | '@user1:test' | '$xxx' | None | - # | 2 | '!example:other.example.com' | 'm.room.create' | '' | '$xxx' | None | - # | 2 | '!example:other.example.com' | 'm.room.member' | '@creator:other.example.com' | '$xxx' | None | - membership_changes = self.get_success( - self.store.get_current_state_delta_membership_changes_for_user( - user1_id, - from_key=before_join_token.room_key, - to_key=after_join_token.room_key, - ) - ) - - join_pos = self.get_success( - self.store.get_position_for_event(join_event.event_id) - ) - - # Let the whole diff show on failure - self.maxDiff = None - self.assertEqual( - membership_changes, - [ - CurrentStateDeltaMembership( - room_id=intially_unjoined_room_id, - event_id=join_event.event_id, - event_pos=join_pos, - membership="join", - sender=user1_id, - prev_event_id=None, - prev_event_pos=None, - prev_membership=None, - prev_sender=None, - ), - ], - ) diff --git a/tests/storage/test_txn_limit.py b/tests/storage/test_txn_limit.py
index 5d58521810..4e99355fda 100644 --- a/tests/storage/test_txn_limit.py +++ b/tests/storage/test_txn_limit.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2014-2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/test_unsafe_locale.py b/tests/storage/test_unsafe_locale.py
index 4f652fc179..9912f1fa28 100644 --- a/tests/storage/test_unsafe_locale.py +++ b/tests/storage/test_unsafe_locale.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index c26932069f..b4289b803a 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -711,10 +710,6 @@ class UserDirectoryICUTestCase(HomeserverTestCase): ), ) - self.assertEqual(_parse_words_with_icu("user-1"), ["user-1"]) - self.assertEqual(_parse_words_with_icu("user-ab"), ["user-ab"]) - self.assertEqual(_parse_words_with_icu("user.--1"), ["user", "-1"]) - def test_regex_word_boundary_punctuation(self) -> None: """ Tests the behaviour of punctuation with the non-ICU tokeniser diff --git a/tests/storage/test_user_filters.py b/tests/storage/test_user_filters.py
index 177da340e5..a39f44bc26 100644 --- a/tests/storage/test_user_filters.py +++ b/tests/storage/test_user_filters.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2023 The Matrix.org Foundation C.I.C # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/util/__init__.py b/tests/storage/util/__init__.py
index dab387a504..3d833a2e44 100644 --- a/tests/storage/util/__init__.py +++ b/tests/storage/util/__init__.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/storage/util/test_partial_state_events_tracker.py b/tests/storage/util/test_partial_state_events_tracker.py
index 1e5663f137..a9e6d1f5c5 100644 --- a/tests/storage/util/test_partial_state_events_tracker.py +++ b/tests/storage/util/test_partial_state_events_tracker.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/test_distributor.py b/tests/test_distributor.py
index 18792fdee3..55da82854a 100644 --- a/tests/test_distributor.py +++ b/tests/test_distributor.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2014-2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 6d1ae4c8d7..972284e55b 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2018-2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/test_federation.py b/tests/test_federation.py
index 4e9adc0625..debba61b42 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py
index 16206d5a97..b2b9cedf8f 100644 --- a/tests/test_phone_home.py +++ b/tests/test_phone_home.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019 Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/test_server.py b/tests/test_server.py
index 9ff2589497..0910ea5f28 100644 --- a/tests/test_server.py +++ b/tests/test_server.py
@@ -392,7 +392,8 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): ) self.assertEqual(channel.code, 301) - location_headers = channel.headers.getRawHeaders(b"Location", []) + headers = channel.result["headers"] + location_headers = [v for k, v in headers if k == b"Location"] self.assertEqual(location_headers, [b"/look/an/eagle"]) def test_redirect_exception_with_cookie(self) -> None: @@ -414,10 +415,10 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): ) self.assertEqual(channel.code, 304) - headers = channel.headers - location_headers = headers.getRawHeaders(b"Location", []) + headers = channel.result["headers"] + location_headers = [v for k, v in headers if k == b"Location"] self.assertEqual(location_headers, [b"/no/over/there"]) - cookies_headers = headers.getRawHeaders(b"Set-Cookie", []) + cookies_headers = [v for k, v in headers if k == b"Set-Cookie"] self.assertEqual(cookies_headers, [b"session=yespls"]) def test_head_request(self) -> None: diff --git a/tests/test_state.py b/tests/test_state.py
index 311a590693..ac3f6b100d 100644 --- a/tests/test_state.py +++ b/tests/test_state.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2014-2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py
index c52f963a7e..21b44510bd 100644 --- a/tests/test_test_utils.py +++ b/tests/test_test_utils.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2014-2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/test_types.py b/tests/test_types.py
index 00adc65a5a..0d9cb17286 100644 --- a/tests/test_types.py +++ b/tests/test_types.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2014-2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -19,18 +18,9 @@ # # -from typing import Type -from unittest import skipUnless - -from immutabledict import immutabledict -from parameterized import parameterized_class - from synapse.api.errors import SynapseError from synapse.types import ( - AbstractMultiWriterStreamToken, - MultiWriterStreamToken, RoomAlias, - RoomStreamToken, UserID, get_domain_from_id, get_localpart_from_id, @@ -38,7 +28,6 @@ from synapse.types import ( ) from tests import unittest -from tests.utils import USE_POSTGRES_FOR_TESTS class IsMineIDTests(unittest.HomeserverTestCase): @@ -137,64 +126,3 @@ class MapUsernameTestCase(unittest.TestCase): # this should work with either a unicode or a bytes self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast") self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast") - - -@parameterized_class( - ("token_type",), - [ - (MultiWriterStreamToken,), - (RoomStreamToken,), - ], - class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{params_dict['token_type'].__name__}", -) -class MultiWriterTokenTestCase(unittest.HomeserverTestCase): - """Tests for the different types of multi writer tokens.""" - - token_type: Type[AbstractMultiWriterStreamToken] - - def test_basic_token(self) -> None: - """Test that a simple stream token can be serialized and unserialized""" - store = self.hs.get_datastores().main - - token = self.token_type(stream=5) - - string_token = self.get_success(token.to_string(store)) - - if isinstance(token, RoomStreamToken): - self.assertEqual(string_token, "s5") - else: - self.assertEqual(string_token, "5") - - parsed_token = self.get_success(self.token_type.parse(store, string_token)) - self.assertEqual(parsed_token, token) - - 
@skipUnless(USE_POSTGRES_FOR_TESTS, "Requires Postgres") - def test_instance_map(self) -> None: - """Test for stream token with instance map""" - store = self.hs.get_datastores().main - - token = self.token_type(stream=5, instance_map=immutabledict({"foo": 6})) - - string_token = self.get_success(token.to_string(store)) - self.assertEqual(string_token, "m5~1.6") - - parsed_token = self.get_success(self.token_type.parse(store, string_token)) - self.assertEqual(parsed_token, token) - - def test_instance_map_assertion(self) -> None: - """Test that we assert values in the instance map are greater than the - min stream position""" - - with self.assertRaises(ValueError): - self.token_type(stream=5, instance_map=immutabledict({"foo": 4})) - - with self.assertRaises(ValueError): - self.token_type(stream=5, instance_map=immutabledict({"foo": 5})) - - def test_parse_bad_token(self) -> None: - """Test that we can parse tokens produced by a bug in Synapse of the - form `m5~`""" - store = self.hs.get_datastores().main - - parsed_token = self.get_success(self.token_type.parse(store, "m5~")) - self.assertEqual(parsed_token, self.token_type(stream=5)) diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 4ab42a02b9..0908e80c14 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019-2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index 35b3245708..a82d25eaf2 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -125,15 +124,13 @@ async def mark_event_as_partial_state( in this table). """ store = hs.get_datastores().main - # Use the store helper to insert into the database so the caches are busted - await store.store_partial_state_room( - room_id=room_id, - servers={hs.hostname}, - device_lists_stream_id=0, - joined_via=hs.hostname, + await store.db_pool.simple_upsert( + table="partial_state_rooms", + keyvalues={"room_id": room_id}, + values={}, + insertion_values={"room_id": room_id}, ) - # FIXME: Bust the cache await store.db_pool.simple_insert( table="partial_state_events", values={ diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py
index a0f39cb130..7ae1423de6 100644 --- a/tests/test_utils/html_parsers.py +++ b/tests/test_utils/html_parsers.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py
index 6c4be1c1f8..9b204b196b 100644 --- a/tests/test_utils/oidc.py +++ b/tests/test_utils/oidc.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 89cbe4e54b..e51f72d65f 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py
@@ -21,19 +21,13 @@ import logging from typing import Optional from unittest.mock import patch -from synapse.api.constants import EventUnsignedContentFields from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext -from synapse.rest import admin -from synapse.rest.client import login, room -from synapse.server import HomeServer -from synapse.types import create_requester +from synapse.types import JsonDict, create_requester from synapse.visibility import filter_events_for_client, filter_events_for_server from tests import unittest -from tests.test_utils.event_injection import inject_event, inject_member_event -from tests.unittest import HomeserverTestCase from tests.utils import create_room logger = logging.getLogger(__name__) @@ -62,31 +56,15 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): # # before we do that, we persist some other events to act as state. - self.get_success( - inject_visibility_event(self.hs, TEST_ROOM_ID, "@admin:hs", "joined") - ) + self._inject_visibility("@admin:hs", "joined") for i in range(10): - self.get_success( - inject_member_event( - self.hs, - TEST_ROOM_ID, - "@resident%i:hs" % i, - "join", - ) - ) + self._inject_room_member("@resident%i:hs" % i) events_to_filter = [] for i in range(10): - evt = self.get_success( - inject_member_event( - self.hs, - TEST_ROOM_ID, - "@user%i:%s" % (i, "test_server" if i == 5 else "other_server"), - "join", - extra_content={"a": "b"}, - ) - ) + user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server") + evt = self._inject_room_member(user, extra_content={"a": "b"}) events_to_filter.append(evt) filtered = self.get_success( @@ -112,19 +90,8 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): def test_filter_outlier(self) -> None: # outlier events must be returned, for the good of the collective federation - self.get_success( - inject_member_event( - 
self.hs, - TEST_ROOM_ID, - "@resident:remote_hs", - "join", - ) - ) - self.get_success( - inject_visibility_event( - self.hs, TEST_ROOM_ID, "@resident:remote_hs", "joined" - ) - ) + self._inject_room_member("@resident:remote_hs") + self._inject_visibility("@resident:remote_hs", "joined") outlier = self._inject_outlier() self.assertEqual( @@ -143,9 +110,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): ) # it should also work when there are other events in the list - evt = self.get_success( - inject_message_event(self.hs, TEST_ROOM_ID, "@unerased:local_hs") - ) + evt = self._inject_message("@unerased:local_hs") filtered = self.get_success( filter_events_for_server( @@ -185,34 +150,19 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): # change in the middle of them. events_to_filter = [] - evt = self.get_success( - inject_message_event(self.hs, TEST_ROOM_ID, "@unerased:local_hs") - ) + evt = self._inject_message("@unerased:local_hs") events_to_filter.append(evt) - evt = self.get_success( - inject_message_event(self.hs, TEST_ROOM_ID, "@erased:local_hs") - ) + evt = self._inject_message("@erased:local_hs") events_to_filter.append(evt) - evt = self.get_success( - inject_member_event( - self.hs, - TEST_ROOM_ID, - "@joiner:remote_hs", - "join", - ) - ) + evt = self._inject_room_member("@joiner:remote_hs") events_to_filter.append(evt) - evt = self.get_success( - inject_message_event(self.hs, TEST_ROOM_ID, "@unerased:local_hs") - ) + evt = self._inject_message("@unerased:local_hs") events_to_filter.append(evt) - evt = self.get_success( - inject_message_event(self.hs, TEST_ROOM_ID, "@erased:local_hs") - ) + evt = self._inject_message("@erased:local_hs") events_to_filter.append(evt) # the erasey user gets erased @@ -250,142 +200,99 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): for i in (1, 4): self.assertNotIn("body", filtered[i].content) - def _inject_outlier(self) -> EventBase: + def _inject_visibility(self, 
user_id: str, visibility: str) -> EventBase: + content = {"history_visibility": visibility} builder = self.event_builder_factory.for_room_version( RoomVersions.V1, { - "type": "m.room.member", - "sender": "@test:user", - "state_key": "@test:user", + "type": "m.room.history_visibility", + "sender": user_id, + "state_key": "", "room_id": TEST_ROOM_ID, - "content": {"membership": "join"}, + "content": content, }, ) - event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[])) - event.internal_metadata.outlier = True - self.get_success( - self._persistence.persist_event( - event, EventContext.for_outlier(self._storage_controllers) - ) + event, unpersisted_context = self.get_success( + self.event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(event)) + self.get_success(self._persistence.persist_event(event, context)) return event + def _inject_room_member( + self, + user_id: str, + membership: str = "join", + extra_content: Optional[JsonDict] = None, + ) -> EventBase: + content = {"membership": membership} + content.update(extra_content or {}) + builder = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": "m.room.member", + "sender": user_id, + "state_key": user_id, + "room_id": TEST_ROOM_ID, + "content": content, + }, + ) -class FilterEventsForClientTestCase(HomeserverTestCase): - servlets = [ - admin.register_servlets, - login.register_servlets, - room.register_servlets, - ] - - def test_joined_history_visibility(self) -> None: - # User joins and leaves room. Should be able to see the join and leave, - # and messages sent between the two, but not before or after. 
+ event, unpersisted_context = self.get_success( + self.event_creation_handler.create_new_client_event(builder) + ) + context = self.get_success(unpersisted_context.persist(event)) - self.register_user("resident", "p1") - resident_token = self.login("resident", "p1") - room_id = self.helper.create_room_as("resident", tok=resident_token) + self.get_success(self._persistence.persist_event(event, context)) + return event - self.get_success( - inject_visibility_event(self.hs, room_id, "@resident:test", "joined") - ) - before_event = self.get_success( - inject_message_event(self.hs, room_id, "@resident:test", body="before") - ) - join_event = self.get_success( - inject_member_event(self.hs, room_id, "@joiner:test", "join") - ) - during_event = self.get_success( - inject_message_event(self.hs, room_id, "@resident:test", body="during") - ) - leave_event = self.get_success( - inject_member_event(self.hs, room_id, "@joiner:test", "leave") - ) - after_event = self.get_success( - inject_message_event(self.hs, room_id, "@resident:test", body="after") + def _inject_message( + self, user_id: str, content: Optional[JsonDict] = None + ) -> EventBase: + if content is None: + content = {"body": "testytest", "msgtype": "m.text"} + builder = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": "m.room.message", + "sender": user_id, + "room_id": TEST_ROOM_ID, + "content": content, + }, ) - # We have to reload the events from the db, to ensure that prev_content is - # populated. - events_to_filter = [ - self.get_success( - self.hs.get_storage_controllers().main.get_event( - e.event_id, - get_prev_content=True, - ) - ) - for e in [ - before_event, - join_event, - during_event, - leave_event, - after_event, - ] - ] - - # Now run the events through the filter, and check that we can see the events - # we expect, and that the membership prop is as expected. 
- # - # We deliberately do the queries for both users upfront; this simulates - # concurrent queries on the server, and helps ensure that we aren't - # accidentally serving the same event object (with the same unsigned.membership - # property) to both users. - joiner_filtered_events = self.get_success( - filter_events_for_client( - self.hs.get_storage_controllers(), - "@joiner:test", - events_to_filter, - ) - ) - resident_filtered_events = self.get_success( - filter_events_for_client( - self.hs.get_storage_controllers(), - "@resident:test", - events_to_filter, - ) + event, unpersisted_context = self.get_success( + self.event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(event)) - # The joiner should be able to seem the join and leave, - # and messages sent between the two, but not before or after. - self.assertEqual( - [e.event_id for e in [join_event, during_event, leave_event]], - [e.event_id for e in joiner_filtered_events], - ) - self.assertEqual( - ["join", "join", "leave"], - [ - e.unsigned[EventUnsignedContentFields.MEMBERSHIP] - for e in joiner_filtered_events - ], - ) + self.get_success(self._persistence.persist_event(event, context)) + return event - # The resident user should see all the events. 
- self.assertEqual( - [ - e.event_id - for e in [ - before_event, - join_event, - during_event, - leave_event, - after_event, - ] - ], - [e.event_id for e in resident_filtered_events], + def _inject_outlier(self) -> EventBase: + builder = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": "m.room.member", + "sender": "@test:user", + "state_key": "@test:user", + "room_id": TEST_ROOM_ID, + "content": {"membership": "join"}, + }, ) - self.assertEqual( - ["join", "join", "join", "join", "join"], - [ - e.unsigned[EventUnsignedContentFields.MEMBERSHIP] - for e in resident_filtered_events - ], + + event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[])) + event.internal_metadata.outlier = True + self.get_success( + self._persistence.persist_event( + event, EventContext.for_outlier(self._storage_controllers) + ) ) + return event -class FilterEventsOutOfBandEventsForClientTestCase( - unittest.FederatingHomeserverTestCase -): +class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): def test_out_of_band_invite_rejection(self) -> None: # this is where we have received an invite event over federation, and then # rejected it. 
@@ -434,23 +341,15 @@ class FilterEventsOutOfBandEventsForClientTestCase( ) # the invited user should be able to see both the invite and the rejection - filtered_events = self.get_success( - filter_events_for_client( - self.hs.get_storage_controllers(), - "@user:test", - [invite_event, reject_event], - ) - ) - self.assertEqual( - [e.event_id for e in filtered_events], - [e.event_id for e in [invite_event, reject_event]], - ) self.assertEqual( - ["invite", "leave"], - [ - e.unsigned[EventUnsignedContentFields.MEMBERSHIP] - for e in filtered_events - ], + self.get_success( + filter_events_for_client( + self.hs.get_storage_controllers(), + "@user:test", + [invite_event, reject_event], + ) + ), + [invite_event, reject_event], ) # other users should see neither @@ -464,34 +363,3 @@ class FilterEventsOutOfBandEventsForClientTestCase( ), [], ) - - -async def inject_visibility_event( - hs: HomeServer, - room_id: str, - sender: str, - visibility: str, -) -> EventBase: - return await inject_event( - hs, - type="m.room.history_visibility", - sender=sender, - state_key="", - room_id=room_id, - content={"history_visibility": visibility}, - ) - - -async def inject_message_event( - hs: HomeServer, - room_id: str, - sender: str, - body: Optional[str] = "testytest", -) -> EventBase: - return await inject_event( - hs, - type="m.room.message", - sender=sender, - room_id=room_id, - content={"body": body, "msgtype": "m.text"}, - ) diff --git a/tests/unittest.py b/tests/unittest.py
index 4aa7f56106..c2e120ffa6 100644 --- a/tests/unittest.py +++ b/tests/unittest.py
@@ -1,8 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019 Matrix.org Federation C.I.C -# Copyright 2014-2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -28,7 +26,6 @@ import logging import secrets import time from typing import ( - AbstractSet, Any, Awaitable, Callable, @@ -110,7 +107,8 @@ class _TypedFailure(Generic[_ExcType], Protocol): """Extension to twisted.Failure, where the 'value' has a certain type.""" @property - def value(self) -> _ExcType: ... + def value(self) -> _ExcType: + ... def around(target: TV) -> Callable[[Callable[Concatenate[S, P], R]], None]: @@ -270,56 +268,6 @@ class TestCase(unittest.TestCase): required[key], actual[key], msg="%s mismatch. %s" % (key, actual) ) - def assertIncludes( - self, - actual_items: AbstractSet[str], - expected_items: AbstractSet[str], - exact: bool = False, - message: Optional[str] = None, - ) -> None: - """ - Assert that all of the `expected_items` are included in the `actual_items`. - - This assert could also be called `assertContains`, `assertItemsInSet` - - Args: - actual_items: The container - expected_items: The items to check for in the container - exact: Whether the actual state should be exactly equal to the expected - state (no extras). - message: Optional message to include in the failure message. 
- """ - # Check that each set has the same items - if exact and actual_items == expected_items: - return - # Check for a superset - elif not exact and actual_items >= expected_items: - return - - expected_lines: List[str] = [] - for expected_item in expected_items: - is_expected_in_actual = expected_item in actual_items - expected_lines.append( - "{} {}".format(" " if is_expected_in_actual else "?", expected_item) - ) - - actual_lines: List[str] = [] - for actual_item in actual_items: - is_actual_in_expected = actual_item in expected_items - actual_lines.append( - "{} {}".format("+" if is_actual_in_expected else " ", actual_item) - ) - - newline = "\n" - expected_string = f"Expected items to be in actual ('?' = missing expected items):\n {{\n{newline.join(expected_lines)}\n }}" - actual_string = f"Actual ('+' = found expected items):\n {{\n{newline.join(actual_lines)}\n }}" - first_message = ( - "Items must match exactly" if exact else "Some expected items are missing." - ) - diff_message = f"{first_message}\n{expected_string}\n{actual_string}" - - self.fail(f"{diff_message}\n{message}") - def DEBUG(target: TV) -> TV: """A decorator to set the .loglevel attribute to logging.DEBUG. @@ -395,8 +343,6 @@ class HomeserverTestCase(TestCase): self._hs_args = {"clock": self.clock, "reactor": self.reactor} self.hs = self.make_homeserver(self.reactor, self.clock) - self.hs.get_datastores().main.tests_allow_no_chain_cover_index = False - # Honour the `use_frozen_dicts` config option. We have to do this # manually because this is taken care of in the app `start` code, which # we don't run. Plus we want to reset it on tearDown. 
@@ -576,7 +522,6 @@ class HomeserverTestCase(TestCase): request: Type[Request] = SynapseRequest, shorthand: bool = True, federation_auth_origin: Optional[bytes] = None, - content_type: Optional[bytes] = None, content_is_form: bool = False, await_result: bool = True, custom_headers: Optional[Iterable[CustomHeaderType]] = None, @@ -595,9 +540,6 @@ class HomeserverTestCase(TestCase): with the usual REST API path, if it doesn't contain it. federation_auth_origin: if set to not-None, we will add a fake Authorization header pretenting to be the given server name. - - content_type: The content-type to use for the request. If not set then will default to - application/json unless content_is_form is true. content_is_form: Whether the content is URL encoded form data. Adds the 'Content-Type': 'application/x-www-form-urlencoded' header. @@ -623,7 +565,6 @@ class HomeserverTestCase(TestCase): request, shorthand, federation_auth_origin, - content_type, content_is_form, await_result, custom_headers, @@ -690,13 +631,13 @@ class HomeserverTestCase(TestCase): return self.successResultOf(deferred) def get_failure( - self, d: Awaitable[Any], exc: Type[_ExcType], by: float = 0.0 + self, d: Awaitable[Any], exc: Type[_ExcType] ) -> _TypedFailure[_ExcType]: """ Run a Deferred and get a Failure from it. The failure must be of the type `exc`. """ deferred: Deferred[Any] = ensureDeferred(d) # type: ignore[arg-type] - self.pump(by) + self.pump() return self.failureResultOf(deferred, exc) def get_success_or_raise(self, d: Awaitable[TV], by: float = 0.0) -> TV: diff --git a/tests/util/__init__.py b/tests/util/__init__.py
index fcd2134c89..3d833a2e44 100644 --- a/tests/util/__init__.py +++ b/tests/util/__init__.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2014-2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/util/caches/__init__.py b/tests/util/caches/__init__.py
index abc9f6d57d..3d833a2e44 100644 --- a/tests/util/caches/__init__.py +++ b/tests/util/caches/__init__.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2017 Vector Creations Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/util/caches/test_cached_call.py b/tests/util/caches/test_cached_call.py
index d6b3d47422..f9e183b65c 100644 --- a/tests/util/caches/test_cached_call.py +++ b/tests/util/caches/test_cached_call.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
index f99f99237e..df1b4f587f 100644 --- a/tests/util/caches/test_deferred_cache.py +++ b/tests/util/caches/test_deferred_cache.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 6af9dfaf56..d403014aad 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/util/caches/test_response_cache.py b/tests/util/caches/test_response_cache.py
index e350967bba..57e7b39b69 100644 --- a/tests/util/caches/test_response_cache.py +++ b/tests/util/caches/test_response_cache.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/util/test_batching_queue.py b/tests/util/test_batching_queue.py
index 2dcf3a3412..897d0b6b31 100644 --- a/tests/util/test_batching_queue.py +++ b/tests/util/test_batching_queue.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2021 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py
index 13a4e6ddaa..65e0793d64 100644 --- a/tests/util/test_check_dependencies.py +++ b/tests/util/test_check_dependencies.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -21,7 +20,6 @@ from contextlib import contextmanager from os import PathLike -from pathlib import Path from typing import Generator, Optional, Union from unittest.mock import patch @@ -42,7 +40,7 @@ class DummyDistribution(metadata.Distribution): def version(self) -> str: return self._version - def locate_file(self, path: Union[str, PathLike]) -> Path: + def locate_file(self, path: Union[str, PathLike]) -> PathLike: raise NotImplementedError() def read_text(self, filename: str) -> None: diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py
index 5055e4aead..c598a6db06 100644 --- a/tests/util/test_dict_cache.py +++ b/tests/util/test_dict_cache.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2015, 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py
index e97e5cf77d..afbf5a926a 100644 --- a/tests/util/test_expiring_cache.py +++ b/tests/util/test_expiring_cache.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2017 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py
index 7a593cc683..5ef691cf03 100644 --- a/tests/util/test_itertools.py +++ b/tests/util/test_itertools.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index 7cbb1007da..c44db6bf8e 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -34,7 +33,8 @@ from tests import unittest class UnblockFunction(Protocol): - def __call__(self, pump_reactor: bool = True) -> None: ... + def __call__(self, pump_reactor: bool = True) -> None: + ... class LinearizerTestCase(unittest.TestCase): diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index f7c5f5faca..40380f6a3e 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2014-2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index 3f0d8139f8..2283a1dbba 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2015, 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -383,34 +382,3 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase): # the items should still be in the cache self.assertEqual(cache.get("key1"), 1) self.assertEqual(cache.get("key2"), 2) - - -class ExtraIndexLruCacheTestCase(unittest.HomeserverTestCase): - def test_invalidate_simple(self) -> None: - cache: LruCache[str, int] = LruCache(10, extra_index_cb=lambda k, v: str(v)) - cache["key1"] = 1 - cache["key2"] = 2 - - cache.invalidate_on_extra_index("key1") - self.assertEqual(cache.get("key1"), 1) - self.assertEqual(cache.get("key2"), 2) - - cache.invalidate_on_extra_index("1") - self.assertEqual(cache.get("key1"), None) - self.assertEqual(cache.get("key2"), 2) - - def test_invalidate_multi(self) -> None: - cache: LruCache[str, int] = LruCache(10, extra_index_cb=lambda k, v: str(v)) - cache["key1"] = 1 - cache["key2"] = 1 - cache["key3"] = 2 - - cache.invalidate_on_extra_index("key1") - self.assertEqual(cache.get("key1"), 1) - self.assertEqual(cache.get("key2"), 1) - self.assertEqual(cache.get("key3"), 2) - - cache.invalidate_on_extra_index("1") - self.assertEqual(cache.get("key1"), None) - self.assertEqual(cache.get("key2"), None) - self.assertEqual(cache.get("key3"), 2) diff --git a/tests/util/test_macaroons.py b/tests/util/test_macaroons.py
index e0507667a2..fa3fa3ff3c 100644 --- a/tests/util/test_macaroons.py +++ b/tests/util/test_macaroons.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/util/test_ratelimitutils.py b/tests/util/test_ratelimitutils.py
index 7bb45f9bf2..2aeba9ab33 100644 --- a/tests/util/test_ratelimitutils.py +++ b/tests/util/test_ratelimitutils.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index 2c286c19a2..98296e611c 100644 --- a/tests/util/test_retryutils.py +++ b/tests/util/test_retryutils.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2019 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py
index 12f821d684..37f97a622d 100644 --- a/tests/util/test_rwlock.py +++ b/tests/util/test_rwlock.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
index af1199ef8a..3df053493b 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py
@@ -1,5 +1,3 @@ -from parameterized import parameterized - from synapse.util.caches.stream_change_cache import StreamChangeCache from tests import unittest @@ -163,8 +161,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): self.assertFalse(cache.has_any_entity_changed(2)) self.assertFalse(cache.has_any_entity_changed(3)) - @parameterized.expand([(0,), (1000000000,)]) - def test_get_entities_changed(self, perf_factor: int) -> None: + def test_get_entities_changed(self) -> None: """ StreamChangeCache.get_entities_changed will return the entities in the given list that have changed since the provided stream ID. If the @@ -181,9 +178,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): # get the ones after that point. self.assertEqual( cache.get_entities_changed( - ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], - stream_pos=2, - _perf_factor=perf_factor, + ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], stream_pos=2 ), {"bar@baz.net", "user@elsewhere.org"}, ) @@ -200,7 +195,6 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): "not@here.website", ], stream_pos=2, - _perf_factor=perf_factor, ), {"bar@baz.net", "user@elsewhere.org"}, ) @@ -216,7 +210,6 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): "not@here.website", ], stream_pos=0, - _perf_factor=perf_factor, ), {"user@foo.com", "bar@baz.net", "user@elsewhere.org", "not@here.website"}, ) @@ -224,11 +217,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): # Query a subset of the entries mid-way through the stream. We should # only get back the subset. 
self.assertEqual( - cache.get_entities_changed( - ["bar@baz.net"], - stream_pos=2, - _perf_factor=perf_factor, - ), + cache.get_entities_changed(["bar@baz.net"], stream_pos=2), {"bar@baz.net"}, ) @@ -249,5 +238,5 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): self.assertEqual(cache.get_max_pos_of_last_change("bar@baz.net"), 3) self.assertEqual(cache.get_max_pos_of_last_change("user@elsewhere.org"), 4) - # Unknown entities will return None - self.assertEqual(cache.get_max_pos_of_last_change("not@here.website"), None) + # Unknown entities will return the stream start position. + self.assertEqual(cache.get_max_pos_of_last_change("not@here.website"), 1) diff --git a/tests/util/test_stringutils.py b/tests/util/test_stringutils.py
index 646fd2163e..b022cea0e8 100644 --- a/tests/util/test_stringutils.py +++ b/tests/util/test_stringutils.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/util/test_task_scheduler.py b/tests/util/test_task_scheduler.py
index 30f0510c9f..ae2e9a334c 100644 --- a/tests/util/test_task_scheduler.py +++ b/tests/util/test_task_scheduler.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2023 The Matrix.org Foundation C.I.C. # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/util/test_threepids.py b/tests/util/test_threepids.py
index 15575cc572..279798f1e5 100644 --- a/tests/util/test_threepids.py +++ b/tests/util/test_threepids.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2020 Dirk Klimpel # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/util/test_treecache.py b/tests/util/test_treecache.py
index 1bc5a7e267..7290958897 100644 --- a/tests/util/test_treecache.py +++ b/tests/util/test_treecache.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2015, 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/util/test_wheel_timer.py b/tests/util/test_wheel_timer.py
index 173a7cfaec..b75a32b73f 100644 --- a/tests/util/test_wheel_timer.py +++ b/tests/util/test_wheel_timer.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify diff --git a/tests/utils.py b/tests/utils.py
index 9fd26ef348..5798b04eef 100644 --- a/tests/utils.py +++ b/tests/utils.py
@@ -1,7 +1,6 @@ # # This file is licensed under the Affero General Public License (AGPL) version 3. # -# Copyright 2014-2016 OpenMarket Ltd # Copyright (C) 2023 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify @@ -21,20 +20,7 @@ import atexit import os -import signal -from types import FrameType, TracebackType -from typing import ( - Any, - Callable, - Dict, - List, - Optional, - Tuple, - Type, - TypeVar, - Union, - overload, -) +from typing import Any, Callable, Dict, List, Tuple, Type, TypeVar, Union, overload import attr from typing_extensions import Literal, ParamSpec @@ -134,11 +120,13 @@ def setupdb() -> None: @overload -def default_config(name: str, parse: Literal[False] = ...) -> Dict[str, object]: ... +def default_config(name: str, parse: Literal[False] = ...) -> Dict[str, object]: + ... @overload -def default_config(name: str, parse: Literal[True]) -> HomeServerConfig: ... +def default_config(name: str, parse: Literal[True]) -> HomeServerConfig: + ... def default_config( @@ -392,30 +380,3 @@ def checked_cast(type: Type[T], x: object) -> T: """ assert isinstance(x, type) return x - - -class TestTimeout(Exception): - pass - - -class test_timeout: - def __init__(self, seconds: int, error_message: Optional[str] = None) -> None: - if error_message is None: - error_message = "test timed out after {}s.".format(seconds) - self.seconds = seconds - self.error_message = error_message - - def handle_timeout(self, signum: int, frame: Optional[FrameType]) -> None: - raise TestTimeout(self.error_message) - - def __enter__(self) -> None: - signal.signal(signal.SIGALRM, self.handle_timeout) - signal.alarm(self.seconds) - - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> None: - signal.alarm(0)